Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change port from uint16_t to uint32_t, to support VSOCK #613

Merged
merged 3 commits into from
Dec 30, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/aws/io/channel_bootstrap.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ struct aws_server_bootstrap {
struct aws_socket_channel_bootstrap_options {
struct aws_client_bootstrap *bootstrap;
const char *host_name;
uint16_t port;
uint32_t port;
const struct aws_socket_options *socket_options;
const struct aws_tls_connection_options *tls_options;
aws_client_bootstrap_on_channel_event_fn *creation_callback;
Expand Down Expand Up @@ -208,7 +208,7 @@ struct aws_socket_channel_bootstrap_options {
struct aws_server_socket_channel_bootstrap_options {
struct aws_server_bootstrap *bootstrap;
const char *host_name;
uint16_t port;
uint32_t port;
const struct aws_socket_options *socket_options;
const struct aws_tls_connection_options *tls_options;
aws_server_bootstrap_on_accept_channel_setup_fn *incoming_callback;
Expand Down
2 changes: 1 addition & 1 deletion include/aws/io/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ typedef void(aws_socket_on_readable_fn)(struct aws_socket *socket, int error_cod
#endif
struct aws_socket_endpoint {
char address[AWS_ADDRESS_MAX_LEN];
uint16_t port;
uint32_t port;
};

struct aws_socket {
Expand Down
12 changes: 6 additions & 6 deletions source/channel_bootstrap.c
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ struct client_connection_args {
aws_client_bootstrap_on_channel_event_fn *shutdown_callback;
struct client_channel_data channel_data;
struct aws_socket_options outgoing_options;
uint16_t outgoing_port;
uint32_t outgoing_port;
struct aws_string *host_name;
void *user_data;
uint8_t addresses_count;
Expand Down Expand Up @@ -764,14 +764,14 @@ int aws_client_bootstrap_new_socket_channel(struct aws_socket_channel_bootstrap_
}

const char *host_name = options->host_name;
uint16_t port = options->port;
uint32_t port = options->port;

AWS_LOGF_TRACE(
AWS_LS_IO_CHANNEL_BOOTSTRAP,
"id=%p: attempting to initialize a new client channel to %s:%d",
"id=%p: attempting to initialize a new client channel to %s:%u",
(void *)bootstrap,
host_name,
(int)port);
port);

aws_ref_count_init(
&client_connection_args->ref_count,
Expand Down Expand Up @@ -1363,10 +1363,10 @@ struct aws_socket *aws_server_bootstrap_new_socket_listener(
AWS_LOGF_INFO(
AWS_LS_IO_CHANNEL_BOOTSTRAP,
"id=%p: attempting to initialize a new "
"server socket listener for %s:%d",
"server socket listener for %s:%u",
(void *)bootstrap_options->bootstrap,
bootstrap_options->host_name,
(int)bootstrap_options->port);
bootstrap_options->port);

aws_ref_count_init(
&server_connection_args->ref_count,
Expand Down
73 changes: 45 additions & 28 deletions source/posix/socket.c
Original file line number Diff line number Diff line change
Expand Up @@ -340,18 +340,7 @@ static int s_update_local_endpoint(struct aws_socket *socket) {
} else if (address.ss_family == AF_VSOCK) {
struct sockaddr_vm *s = (struct sockaddr_vm *)&address;

/* VSOCK port is 32bit, but aws_socket_endpoint.port is only 16bit.
* Hopefully this isn't an issue, since users can only pass in 16bit values.
* But if it becomes an issue, we'll need to make aws_socket_endpoint more flexible */
if (s->svm_port > UINT16_MAX) {
AWS_LOGF_ERROR(
AWS_LS_IO_SOCKET,
"id=%p fd=%d: aws_socket_endpoint can't deal with VSOCK port > UINT16_MAX",
(void *)socket,
socket->io_handle.data.fd);
return aws_raise_error(AWS_IO_SOCKET_INVALID_ADDRESS);
}
tmp_endpoint.port = (uint16_t)s->svm_port;
tmp_endpoint.port = s->svm_port;

snprintf(tmp_endpoint.address, sizeof(tmp_endpoint.address), "%" PRIu32, s->svm_cid);
return AWS_OP_SUCCESS;
Expand Down Expand Up @@ -646,14 +635,17 @@ int aws_socket_connect(
AWS_ZERO_STRUCT(address);
socklen_t sock_size = 0;
int pton_err = 1;
bool port_is_bad = false;
graebm marked this conversation as resolved.
Show resolved Hide resolved
if (socket->options.domain == AWS_SOCKET_IPV4) {
pton_err = inet_pton(AF_INET, remote_endpoint->address, &address.sock_addr_types.addr_in.sin_addr);
address.sock_addr_types.addr_in.sin_port = htons(remote_endpoint->port);
port_is_bad = remote_endpoint->port > UINT16_MAX;
address.sock_addr_types.addr_in.sin_port = htons((uint16_t)remote_endpoint->port);
address.sock_addr_types.addr_in.sin_family = AF_INET;
sock_size = sizeof(address.sock_addr_types.addr_in);
} else if (socket->options.domain == AWS_SOCKET_IPV6) {
pton_err = inet_pton(AF_INET6, remote_endpoint->address, &address.sock_addr_types.addr_in6.sin6_addr);
address.sock_addr_types.addr_in6.sin6_port = htons(remote_endpoint->port);
port_is_bad = remote_endpoint->port > UINT16_MAX;
address.sock_addr_types.addr_in6.sin6_port = htons((uint16_t)remote_endpoint->port);
address.sock_addr_types.addr_in6.sin6_family = AF_INET6;
sock_size = sizeof(address.sock_addr_types.addr_in6);
} else if (socket->options.domain == AWS_SOCKET_LOCAL) {
Expand All @@ -664,7 +656,7 @@ int aws_socket_connect(
} else if (socket->options.domain == AWS_SOCKET_VSOCK) {
pton_err = parse_cid(remote_endpoint->address, &address.sock_addr_types.vm_addr.svm_cid);
address.sock_addr_types.vm_addr.svm_family = AF_VSOCK;
address.sock_addr_types.vm_addr.svm_port = (unsigned int)remote_endpoint->port;
address.sock_addr_types.vm_addr.svm_port = remote_endpoint->port;
sock_size = sizeof(address.sock_addr_types.vm_addr);
#endif
} else {
Expand All @@ -676,21 +668,32 @@ int aws_socket_connect(
int errno_value = errno; /* Always cache errno before potential side-effect */
AWS_LOGF_DEBUG(
AWS_LS_IO_SOCKET,
"id=%p fd=%d: failed to parse address %s:%d.",
"id=%p fd=%d: failed to parse address %s:%u.",
(void *)socket,
socket->io_handle.data.fd,
remote_endpoint->address,
(int)remote_endpoint->port);
remote_endpoint->port);
return aws_raise_error(s_convert_pton_error(pton_err, errno_value));
}

if (port_is_bad) {
AWS_LOGF_DEBUG(
AWS_LS_IO_SOCKET,
"id=%p fd=%d: illegal port value, too high for this socket type. %s:%u",
(void *)socket,
socket->io_handle.data.fd,
remote_endpoint->address,
remote_endpoint->port);
return aws_raise_error(AWS_IO_SOCKET_INVALID_ADDRESS);
}

AWS_LOGF_DEBUG(
AWS_LS_IO_SOCKET,
"id=%p fd=%d: connecting to endpoint %s:%d.",
"id=%p fd=%d: connecting to endpoint %s:%u.",
(void *)socket,
socket->io_handle.data.fd,
remote_endpoint->address,
(int)remote_endpoint->port);
remote_endpoint->port);

socket->state = CONNECTING;
socket->remote_endpoint = *remote_endpoint;
Expand Down Expand Up @@ -808,24 +811,27 @@ int aws_socket_bind(struct aws_socket *socket, const struct aws_socket_endpoint

AWS_LOGF_INFO(
AWS_LS_IO_SOCKET,
"id=%p fd=%d: binding to %s:%d.",
"id=%p fd=%d: binding to %s:%u.",
(void *)socket,
socket->io_handle.data.fd,
local_endpoint->address,
(int)local_endpoint->port);
local_endpoint->port);

struct socket_address address;
AWS_ZERO_STRUCT(address);
socklen_t sock_size = 0;
int pton_err = 1;
bool port_is_bad = false;
graebm marked this conversation as resolved.
Show resolved Hide resolved
if (socket->options.domain == AWS_SOCKET_IPV4) {
pton_err = inet_pton(AF_INET, local_endpoint->address, &address.sock_addr_types.addr_in.sin_addr);
address.sock_addr_types.addr_in.sin_port = htons(local_endpoint->port);
port_is_bad = local_endpoint->port > UINT16_MAX;
address.sock_addr_types.addr_in.sin_port = htons((uint16_t)local_endpoint->port);
address.sock_addr_types.addr_in.sin_family = AF_INET;
sock_size = sizeof(address.sock_addr_types.addr_in);
} else if (socket->options.domain == AWS_SOCKET_IPV6) {
pton_err = inet_pton(AF_INET6, local_endpoint->address, &address.sock_addr_types.addr_in6.sin6_addr);
address.sock_addr_types.addr_in6.sin6_port = htons(local_endpoint->port);
port_is_bad = local_endpoint->port > UINT16_MAX;
address.sock_addr_types.addr_in6.sin6_port = htons((uint16_t)local_endpoint->port);
address.sock_addr_types.addr_in6.sin6_family = AF_INET6;
sock_size = sizeof(address.sock_addr_types.addr_in6);
} else if (socket->options.domain == AWS_SOCKET_LOCAL) {
Expand All @@ -836,7 +842,7 @@ int aws_socket_bind(struct aws_socket *socket, const struct aws_socket_endpoint
} else if (socket->options.domain == AWS_SOCKET_VSOCK) {
pton_err = parse_cid(local_endpoint->address, &address.sock_addr_types.vm_addr.svm_cid);
address.sock_addr_types.vm_addr.svm_family = AF_VSOCK;
address.sock_addr_types.vm_addr.svm_port = (unsigned int)local_endpoint->port;
address.sock_addr_types.vm_addr.svm_port = local_endpoint->port;
sock_size = sizeof(address.sock_addr_types.vm_addr);
#endif
} else {
Expand All @@ -848,14 +854,25 @@ int aws_socket_bind(struct aws_socket *socket, const struct aws_socket_endpoint
int errno_value = errno; /* Always cache errno before potential side-effect */
AWS_LOGF_ERROR(
AWS_LS_IO_SOCKET,
"id=%p fd=%d: failed to parse address %s:%d.",
"id=%p fd=%d: failed to parse address %s:%u.",
(void *)socket,
socket->io_handle.data.fd,
local_endpoint->address,
(int)local_endpoint->port);
local_endpoint->port);
return aws_raise_error(s_convert_pton_error(pton_err, errno_value));
}

if (port_is_bad) {
AWS_LOGF_ERROR(
AWS_LS_IO_SOCKET,
"id=%p fd=%d: illegal port value, too high for this socket type. %s:%u",
(void *)socket,
socket->io_handle.data.fd,
local_endpoint->address,
local_endpoint->port);
return aws_raise_error(AWS_IO_SOCKET_INVALID_ADDRESS);
}

if (bind(socket->io_handle.data.fd, (struct sockaddr *)&address.sock_addr_types, sock_size) != 0) {
int errno_value = errno; /* Always cache errno before potential side-effect */
AWS_LOGF_ERROR(
Expand All @@ -882,7 +899,7 @@ int aws_socket_bind(struct aws_socket *socket, const struct aws_socket_endpoint

AWS_LOGF_DEBUG(
AWS_LS_IO_SOCKET,
"id=%p fd=%d: successfully bound to %s:%d",
"id=%p fd=%d: successfully bound to %s:%u",
(void *)socket,
socket->io_handle.data.fd,
socket->local_endpoint.address,
Expand Down Expand Up @@ -996,7 +1013,7 @@ static void s_socket_accept_event(

new_sock->local_endpoint = socket->local_endpoint;
new_sock->state = CONNECTED_READ | CONNECTED_WRITE;
uint16_t port = 0;
uint32_t port = 0;

/* get the info on the incoming socket's address */
if (in_addr.ss_family == AF_INET) {
Expand Down
Loading