Skip to content

Commit

Permalink
Bret wants to be able to validate the port before connect() is called
Browse files Browse the repository at this point in the history
  • Loading branch information
graebm committed Dec 14, 2023
1 parent eb29a1a commit 965af4e
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 85 deletions.
16 changes: 16 additions & 0 deletions include/aws/io/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,22 @@ AWS_IO_API int aws_socket_get_error(struct aws_socket *socket);
*/
AWS_IO_API bool aws_socket_is_open(struct aws_socket *socket);

/**
* Raises AWS_IO_SOCKET_INVALID_ADDRESS and logs an error if connecting to this port is illegal.
* For example, port must be in range 1-65535 to connect with IPv4.
* These port values would fail eventually in aws_socket_connect(),
* but you can use this function to validate earlier.
*/
AWS_IO_API int aws_socket_validate_port_for_connect(uint32_t port, enum aws_socket_domain domain);

/**
* Raises AWS_IO_SOCKET_INVALID_ADDRESS and logs an error if binding to this port is illegal.
* For example, port must in range 0-65535 to bind with IPv4.
* These port values would fail eventually in aws_socket_bind(),
* but you can use this function to validate earlier.
*/
AWS_IO_API int aws_socket_validate_port_for_bind(uint32_t port, enum aws_socket_domain domain);

/**
* Assigns a random address (UUID) for use with AWS_SOCKET_LOCAL (Unix Domain Sockets).
* For use in internal tests only.
Expand Down
36 changes: 8 additions & 28 deletions source/posix/socket.c
Original file line number Diff line number Diff line change
Expand Up @@ -631,20 +631,21 @@ int aws_socket_connect(
return AWS_OP_ERR;
}

if (aws_socket_validate_port_for_connect(remote_endpoint->port, socket->options.domain)) {
return AWS_OP_ERR;
}

struct socket_address address;
AWS_ZERO_STRUCT(address);
socklen_t sock_size = 0;
int pton_err = 1;
bool port_is_bad = false;
if (socket->options.domain == AWS_SOCKET_IPV4) {
pton_err = inet_pton(AF_INET, remote_endpoint->address, &address.sock_addr_types.addr_in.sin_addr);
port_is_bad = remote_endpoint->port > UINT16_MAX;
address.sock_addr_types.addr_in.sin_port = htons((uint16_t)remote_endpoint->port);
address.sock_addr_types.addr_in.sin_family = AF_INET;
sock_size = sizeof(address.sock_addr_types.addr_in);
} else if (socket->options.domain == AWS_SOCKET_IPV6) {
pton_err = inet_pton(AF_INET6, remote_endpoint->address, &address.sock_addr_types.addr_in6.sin6_addr);
port_is_bad = remote_endpoint->port > UINT16_MAX;
address.sock_addr_types.addr_in6.sin6_port = htons((uint16_t)remote_endpoint->port);
address.sock_addr_types.addr_in6.sin6_family = AF_INET6;
sock_size = sizeof(address.sock_addr_types.addr_in6);
Expand Down Expand Up @@ -676,17 +677,6 @@ int aws_socket_connect(
return aws_raise_error(s_convert_pton_error(pton_err, errno_value));
}

if (port_is_bad) {
AWS_LOGF_DEBUG(
AWS_LS_IO_SOCKET,
"id=%p fd=%d: illegal port value, too high for this socket type. %s:%u",
(void *)socket,
socket->io_handle.data.fd,
remote_endpoint->address,
remote_endpoint->port);
return aws_raise_error(AWS_IO_SOCKET_INVALID_ADDRESS);
}

AWS_LOGF_DEBUG(
AWS_LS_IO_SOCKET,
"id=%p fd=%d: connecting to endpoint %s:%u.",
Expand Down Expand Up @@ -809,6 +799,10 @@ int aws_socket_bind(struct aws_socket *socket, const struct aws_socket_endpoint
return AWS_OP_ERR;
}

if (aws_socket_validate_port_for_bind(local_endpoint->port, socket->options.domain)) {
return AWS_OP_ERR;
}

AWS_LOGF_INFO(
AWS_LS_IO_SOCKET,
"id=%p fd=%d: binding to %s:%u.",
Expand All @@ -821,16 +815,13 @@ int aws_socket_bind(struct aws_socket *socket, const struct aws_socket_endpoint
AWS_ZERO_STRUCT(address);
socklen_t sock_size = 0;
int pton_err = 1;
bool port_is_bad = false;
if (socket->options.domain == AWS_SOCKET_IPV4) {
pton_err = inet_pton(AF_INET, local_endpoint->address, &address.sock_addr_types.addr_in.sin_addr);
port_is_bad = local_endpoint->port > UINT16_MAX;
address.sock_addr_types.addr_in.sin_port = htons((uint16_t)local_endpoint->port);
address.sock_addr_types.addr_in.sin_family = AF_INET;
sock_size = sizeof(address.sock_addr_types.addr_in);
} else if (socket->options.domain == AWS_SOCKET_IPV6) {
pton_err = inet_pton(AF_INET6, local_endpoint->address, &address.sock_addr_types.addr_in6.sin6_addr);
port_is_bad = local_endpoint->port > UINT16_MAX;
address.sock_addr_types.addr_in6.sin6_port = htons((uint16_t)local_endpoint->port);
address.sock_addr_types.addr_in6.sin6_family = AF_INET6;
sock_size = sizeof(address.sock_addr_types.addr_in6);
Expand Down Expand Up @@ -862,17 +853,6 @@ int aws_socket_bind(struct aws_socket *socket, const struct aws_socket_endpoint
return aws_raise_error(s_convert_pton_error(pton_err, errno_value));
}

if (port_is_bad) {
AWS_LOGF_ERROR(
AWS_LS_IO_SOCKET,
"id=%p fd=%d: illegal port value, too high for this socket type. %s:%u",
(void *)socket,
socket->io_handle.data.fd,
local_endpoint->address,
local_endpoint->port);
return aws_raise_error(AWS_IO_SOCKET_INVALID_ADDRESS);
}

if (bind(socket->io_handle.data.fd, (struct sockaddr *)&address.sock_addr_types, sock_size) != 0) {
int errno_value = errno; /* Always cache errno before potential side-effect */
AWS_LOGF_ERROR(
Expand Down
75 changes: 75 additions & 0 deletions source/socket_shared.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/**
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/
#include <aws/io/socket.h>

#include <aws/io/logging.h>

/* common validation for connect() and bind() */
static int s_socket_validate_port_for_domain(uint32_t port, enum aws_socket_domain domain) {
switch (domain) {
case AWS_SOCKET_IPV4:
case AWS_SOCKET_IPV6:
if (port > UINT16_MAX) {
AWS_LOGF_ERROR(
AWS_LS_IO_SOCKET,
"Invalid port=%u for %s. Cannot exceed 65535",
port,
domain == AWS_SOCKET_IPV4 ? "IPv4" : "IPv6");
return aws_raise_error(AWS_IO_SOCKET_INVALID_ADDRESS);
}
break;

case AWS_SOCKET_LOCAL:
/* port is ignored */
break;

case AWS_SOCKET_VSOCK:
/* any 32bit port is legal */
break;

default:
AWS_LOGF_ERROR(AWS_LS_IO_SOCKET, "Cannot validate port for unknown domain=%d", domain);
return aws_raise_error(AWS_IO_SOCKET_INVALID_ADDRESS);
}
return AWS_OP_SUCCESS;
}

int aws_socket_validate_port_for_connect(uint32_t port, enum aws_socket_domain domain) {
if (s_socket_validate_port_for_domain(port, domain)) {
return AWS_OP_ERR;
}

/* additional validation */
switch (domain) {
case AWS_SOCKET_IPV4:
case AWS_SOCKET_IPV6:
if (port == 0) {
AWS_LOGF_ERROR(
AWS_LS_IO_SOCKET,
"Invalid port=%u for %s connections. Must use 1-65535",
port,
domain == AWS_SOCKET_IPV4 ? "IPv4" : "IPv6");
return aws_raise_error(AWS_IO_SOCKET_INVALID_ADDRESS);
}
break;

case AWS_SOCKET_VSOCK:
if (port == -1U) {
AWS_LOGF_ERROR(
AWS_LS_IO_SOCKET, "Invalid port for VSOCK connections. Cannot use VMADDR_PORT_ANY (-1U).");
return aws_raise_error(AWS_IO_SOCKET_INVALID_ADDRESS);
}
break;

default:
/* no extra validation */
break;
}
return AWS_OP_SUCCESS;
}

int aws_socket_validate_port_for_bind(uint32_t port, enum aws_socket_domain domain) {
return s_socket_validate_port_for_domain(port, domain);
}
67 changes: 10 additions & 57 deletions source/windows/iocp/socket.c
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,11 @@ int aws_socket_connect(
return aws_raise_error(AWS_IO_SOCKET_ILLEGAL_OPERATION_FOR_STATE);
}
}

if (aws_socket_validate_port_for_connect(remote_endpoint->port, socket->option.domain)) {
return AWS_OP_ERR;
}

return socket_impl->vtable->connect(socket, remote_endpoint, event_loop, on_connection_result, user_data);
}

Expand All @@ -457,6 +462,11 @@ int aws_socket_bind(struct aws_socket *socket, const struct aws_socket_endpoint
socket->state = ERRORED;
return aws_raise_error(AWS_IO_SOCKET_ILLEGAL_OPERATION_FOR_STATE);
}

if (aws_socket_validate_port_for_bind(local_endpoint->port, socket->options.domain)) {
return AWS_OP_ERR;
}

struct iocp_socket *socket_impl = socket->impl;
return socket_impl->vtable->bind(socket, local_endpoint);
}
Expand Down Expand Up @@ -1058,17 +1068,6 @@ static int s_ipv4_stream_connect(
return aws_raise_error(aws_err);
}

if (remote_endpoint->port > UINT16_MAX) {
AWS_LOGF_ERROR(
AWS_LS_IO_SOCKET,
"id=%p handle=%p: illegal port value, too high for IPV4. %s:%u",
(void *)socket,
(void *)socket->io_handle.data.handle,
remote_endpoint->address,
remote_endpoint->port);
return aws_raise_error(AWS_IO_SOCKET_INVALID_ADDRESS);
}

AWS_LOGF_DEBUG(
AWS_LS_IO_SOCKET,
"id=%p handle=%p: connecting to endpoint %s:%u.",
Expand Down Expand Up @@ -1136,17 +1135,6 @@ static int s_ipv6_stream_connect(
return aws_raise_error(aws_err);
}

if (remote_endpoint->port > UINT16_MAX) {
AWS_LOGF_ERROR(
AWS_LS_IO_SOCKET,
"id=%p handle=%p: illegal port value, too high for IPV6. %s:%u",
(void *)socket,
(void *)socket->io_handle.data.handle,
remote_endpoint->address,
remote_endpoint->port);
return aws_raise_error(AWS_IO_SOCKET_INVALID_ADDRESS);
}

AWS_LOGF_DEBUG(
AWS_LS_IO_SOCKET,
"id=%p handle=%p: connecting to endpoint %s:%u.",
Expand Down Expand Up @@ -1356,10 +1344,6 @@ static int s_ipv4_dgram_connect(
return aws_raise_error(aws_err);
}

if (remote_endpoint->port > UINT16_MAX) {
socket->state = ERRORED;
return aws_raise_error(AWS_IO_SOCKET_INVALID_ADDRESS);
}
addr_in.sin_port = htons((uint16_t)remote_endpoint->port);
addr_in.sin_family = AF_INET;

Expand Down Expand Up @@ -1387,17 +1371,6 @@ static int s_ipv6_dgram_connect(
return aws_raise_error(aws_err);
}

if (remote_endpoint->port > UINT16_MAX) {
AWS_LOGF_ERROR(
AWS_LS_IO_SOCKET,
"id=%p handle=%p: illegal port value, too high for IPV6. %s:%u",
(void *)socket,
(void *)socket->io_handle.data.handle,
remote_endpoint->address,
remote_endpoint->port);
return aws_raise_error(AWS_IO_SOCKET_INVALID_ADDRESS);
}

addr_in6.sin6_port = htons((uint16_t)remote_endpoint->port);
addr_in6.sin6_family = AF_INET6;

Expand Down Expand Up @@ -1468,11 +1441,6 @@ static int s_ipv4_stream_bind(struct aws_socket *socket, const struct aws_socket
return aws_raise_error(aws_err);
}

if (local_endpoint->port > UINT16_MAX) {
socket->state = ERRORED;
return aws_raise_error(AWS_IO_SOCKET_INVALID_ADDRESS);
}

addr_in.sin_port = htons((uint16_t)local_endpoint->port);
addr_in.sin_family = AF_INET;

Expand All @@ -1490,11 +1458,6 @@ static int s_ipv6_stream_bind(struct aws_socket *socket, const struct aws_socket
return aws_raise_error(aws_err);
}

if (local_endpoint->port > UINT16_MAX) {
socket->state = ERRORED;
return aws_raise_error(AWS_IO_SOCKET_INVALID_ADDRESS);
}

addr_in6.sin6_port = htons((uint16_t)local_endpoint->port);
addr_in6.sin6_family = AF_INET6;

Expand Down Expand Up @@ -1546,11 +1509,6 @@ static int s_ipv4_dgram_bind(struct aws_socket *socket, const struct aws_socket_
return aws_raise_error(aws_err);
}

if (local_endpoint->port > UINT16_MAX) {
socket->state = ERRORED;
return aws_raise_error(AWS_IO_SOCKET_INVALID_ADDRESS);
}

addr_in.sin_port = htons((uint16_t)local_endpoint->port);
addr_in.sin_family = AF_INET;

Expand All @@ -1568,11 +1526,6 @@ static int s_ipv6_dgram_bind(struct aws_socket *socket, const struct aws_socket_
return aws_raise_error(aws_err);
}

if (local_endpoint->port > UINT16_MAX) {
socket->state = ERRORED;
return aws_raise_error(AWS_IO_SOCKET_INVALID_ADDRESS);
}

addr_in6.sin6_port = htons((uint16_t)local_endpoint->port);
addr_in6.sin6_family = AF_INET6;

Expand Down
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ add_net_test_case(cleanup_before_connect_or_timeout_doesnt_explode)
add_test_case(cleanup_in_accept_doesnt_explode)
add_test_case(cleanup_in_write_cb_doesnt_explode)
add_test_case(sock_write_cb_is_async)
add_test_case(socket_validate_port)

if(WIN32)
add_test_case(local_socket_pipe_connected_race)
Expand Down
53 changes: 53 additions & 0 deletions tests/socket_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -1752,3 +1752,56 @@ static int s_local_socket_pipe_connected_race(struct aws_allocator *allocator, v
AWS_TEST_CASE(local_socket_pipe_connected_race, s_local_socket_pipe_connected_race)

#endif

static int s_test_socket_validate_port(struct aws_allocator *allocator, void *ctx) {
(void)allocator;
(void)ctx;

/* IPv4 - 16bit port, only bind can use 0 */
ASSERT_SUCCESS(aws_socket_validate_port_for_connect(80, AWS_SOCKET_IPV4));
ASSERT_SUCCESS(aws_socket_validate_port_for_bind(80, AWS_SOCKET_IPV4));

ASSERT_ERROR(AWS_IO_SOCKET_INVALID_ADDRESS, aws_socket_validate_port_for_connect(0, AWS_SOCKET_IPV4));
ASSERT_SUCCESS(aws_socket_validate_port_for_bind(0, AWS_SOCKET_IPV4));

ASSERT_ERROR(AWS_IO_SOCKET_INVALID_ADDRESS, aws_socket_validate_port_for_connect(0xFFFFFFFF, AWS_SOCKET_IPV4));
ASSERT_ERROR(AWS_IO_SOCKET_INVALID_ADDRESS, aws_socket_validate_port_for_bind(0xFFFFFFFF, AWS_SOCKET_IPV4));

/* IPv6 - 16bit port, only bind can use 0 */
ASSERT_SUCCESS(aws_socket_validate_port_for_connect(80, AWS_SOCKET_IPV6));
ASSERT_SUCCESS(aws_socket_validate_port_for_bind(80, AWS_SOCKET_IPV6));

ASSERT_ERROR(AWS_IO_SOCKET_INVALID_ADDRESS, aws_socket_validate_port_for_connect(0, AWS_SOCKET_IPV6));
ASSERT_SUCCESS(aws_socket_validate_port_for_bind(0, AWS_SOCKET_IPV6));

ASSERT_ERROR(AWS_IO_SOCKET_INVALID_ADDRESS, aws_socket_validate_port_for_connect(0xFFFFFFFF, AWS_SOCKET_IPV6));
ASSERT_ERROR(AWS_IO_SOCKET_INVALID_ADDRESS, aws_socket_validate_port_for_bind(0xFFFFFFFF, AWS_SOCKET_IPV6));

/* VSOCK - 32bit port, only bind can use VMADDR_PORT_ANY (-1U) */
ASSERT_SUCCESS(aws_socket_validate_port_for_connect(80, AWS_SOCKET_VSOCK));
ASSERT_SUCCESS(aws_socket_validate_port_for_bind(80, AWS_SOCKET_VSOCK));

ASSERT_SUCCESS(aws_socket_validate_port_for_connect(0, AWS_SOCKET_VSOCK));
ASSERT_SUCCESS(aws_socket_validate_port_for_bind(0, AWS_SOCKET_VSOCK));

ASSERT_SUCCESS(aws_socket_validate_port_for_connect(0x7FFFFFFF, AWS_SOCKET_VSOCK));
ASSERT_SUCCESS(aws_socket_validate_port_for_bind(0x7FFFFFFF, AWS_SOCKET_VSOCK));

ASSERT_ERROR(AWS_IO_SOCKET_INVALID_ADDRESS, aws_socket_validate_port_for_connect(-1U, AWS_SOCKET_VSOCK));
ASSERT_SUCCESS(aws_socket_validate_port_for_bind(-1U, AWS_SOCKET_VSOCK));

/* LOCAL - ignores port */
ASSERT_SUCCESS(aws_socket_validate_port_for_connect(0, AWS_SOCKET_LOCAL));
ASSERT_SUCCESS(aws_socket_validate_port_for_bind(0, AWS_SOCKET_LOCAL));
ASSERT_SUCCESS(aws_socket_validate_port_for_connect(1, AWS_SOCKET_LOCAL));
ASSERT_SUCCESS(aws_socket_validate_port_for_bind(1, AWS_SOCKET_LOCAL));
ASSERT_SUCCESS(aws_socket_validate_port_for_connect(-1U, AWS_SOCKET_LOCAL));
ASSERT_SUCCESS(aws_socket_validate_port_for_bind(-1U, AWS_SOCKET_LOCAL));

/* invalid domain should fail */
ASSERT_ERROR(AWS_IO_SOCKET_INVALID_ADDRESS, aws_socket_validate_port_for_connect(80, (enum aws_socket_domain)(-1)));
ASSERT_ERROR(AWS_IO_SOCKET_INVALID_ADDRESS, aws_socket_validate_port_for_bind(80, (enum aws_socket_domain)(-1)));

return 0;
}
AWS_TEST_CASE(socket_validate_port, s_test_socket_validate_port)

0 comments on commit 965af4e

Please sign in to comment.