diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 62bdfe610..f9774c160 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -80,6 +80,7 @@ jobs: clang-sanitizers: runs-on: ubuntu-22.04 # latest strategy: + fail-fast: false matrix: sanitizers: [",thread", ",address,undefined"] steps: diff --git a/CMakeLists.txt b/CMakeLists.txt index 55ba52bcc..707d60d7f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -39,6 +39,8 @@ option(BUILD_RELOCATABLE_BINARIES OFF) option(BYO_CRYPTO "Don't build a tls implementation or link against a crypto interface. This feature is only for unix builds currently." OFF) +# DEBUG: directly set AWS_USE_DISPATCH_QUEUE +set (AWS_USE_DISPATCH_QUEUE ON) file(GLOB AWS_IO_HEADERS "include/aws/io/*.h" @@ -116,7 +118,8 @@ elseif (APPLE) file(GLOB AWS_IO_OS_SRC "source/bsd/*.c" "source/posix/*.c" - "source/darwin/*.c" + "source/darwin/darwin_pki_utils.c" + "source/darwin/secure_transport_tls_channel_handler.c" ) find_library(SECURITY_LIB Security) @@ -132,8 +135,16 @@ elseif (APPLE) #No choice on TLS for apple, darwinssl will always be used. 
list(APPEND PLATFORM_LIBS "-framework Security -framework Network") - # DEBUG WIP We will add a check here to use kqueue queue for macOS and dispatch queue for iOS - set(EVENT_LOOP_DEFINES "-DAWS_USE_DISPATCH_QUEUE -DAWS_USE_KQUEUE") + if(AWS_USE_DISPATCH_QUEUE OR IOS) + set(EVENT_LOOP_DEFINES "-DAWS_USE_DISPATCH_QUEUE" ) + message("use dispatch queue") + file(GLOB AWS_IO_DISPATCH_QUEUE_SRC + "source/darwin/dispatch_queue_event_loop.c" + ) + list(APPEND AWS_IO_OS_SRC ${AWS_IO_DISPATCH_QUEUE_SRC}) + else () + set(EVENT_LOOP_DEFINES "-DAWS_USE_KQUEUE") + endif() elseif (CMAKE_SYSTEM_NAME STREQUAL "FreeBSD" OR CMAKE_SYSTEM_NAME STREQUAL "NetBSD" OR CMAKE_SYSTEM_NAME STREQUAL "OpenBSD") file(GLOB AWS_IO_OS_HEADERS diff --git a/include/aws/io/event_loop.h b/include/aws/io/event_loop.h index 74e9c195c..e021ab4b5 100644 --- a/include/aws/io/event_loop.h +++ b/include/aws/io/event_loop.h @@ -70,7 +70,7 @@ struct aws_overlapped { void *user_data; }; -#else /* !AWS_USE_IO_COMPLETION_PORTS */ +#endif /* AWS_USE_IO_COMPLETION_PORTS */ typedef void(aws_event_loop_on_event_fn)( struct aws_event_loop *event_loop, @@ -78,8 +78,6 @@ typedef void(aws_event_loop_on_event_fn)( int events, void *user_data); -#endif /* AWS_USE_IO_COMPLETION_PORTS */ - enum aws_event_loop_style { AWS_EVENT_LOOP_STYLE_UNDEFINED = 0, AWS_EVENT_LOOP_STYLE_POLL_BASED = 1, diff --git a/include/aws/io/io.h b/include/aws/io/io.h index 719996525..f7feebdcb 100644 --- a/include/aws/io/io.h +++ b/include/aws/io/io.h @@ -16,7 +16,7 @@ AWS_PUSH_SANE_WARNING_LEVEL struct aws_io_handle; -#if AWS_USE_DISPATCH_QUEUE +#ifdef AWS_USE_DISPATCH_QUEUE typedef void aws_io_set_queue_on_handle_fn(struct aws_io_handle *handle, void *queue); typedef void aws_io_clear_queue_on_handle_fn(struct aws_io_handle *handle); #endif /* AWS_USE_DISPATCH_QUEUE */ diff --git a/include/aws/io/private/tls_channel_handler_shared.h b/include/aws/io/private/tls_channel_handler_shared.h index 4755cd8d0..034321da1 100644 --- 
a/include/aws/io/private/tls_channel_handler_shared.h +++ b/include/aws/io/private/tls_channel_handler_shared.h @@ -33,6 +33,12 @@ struct secure_transport_ctx { bool verify_peer; }; +enum aws_tls_handler_read_state { + AWS_TLS_HANDLER_OPEN, + AWS_TLS_HANDLER_READ_SHUTTING_DOWN, + AWS_TLS_HANDLER_READ_SHUT_DOWN_COMPLETE, +}; + AWS_EXTERN_C_BEGIN AWS_IO_API void aws_tls_channel_handler_shared_init( diff --git a/source/channel.c b/source/channel.c index c387844f4..55903fc1d 100644 --- a/source/channel.c +++ b/source/channel.c @@ -828,7 +828,7 @@ static void s_window_update_task(struct aws_channel_task *channel_task, void *ar channel->window_update_scheduled = false; - if (status == AWS_TASK_STATUS_RUN_READY && channel->channel_state < AWS_CHANNEL_SHUTTING_DOWN) { + if (status == AWS_TASK_STATUS_RUN_READY && channel->channel_state < AWS_CHANNEL_SHUT_DOWN) { /* get the right-most slot to start the updates. */ struct aws_channel_slot *slot = channel->first; while (slot->adj_right) { @@ -858,7 +858,7 @@ static void s_window_update_task(struct aws_channel_task *channel_task, void *ar int aws_channel_slot_increment_read_window(struct aws_channel_slot *slot, size_t window) { - if (slot->channel->read_back_pressure_enabled && slot->channel->channel_state < AWS_CHANNEL_SHUTTING_DOWN) { + if (slot->channel->read_back_pressure_enabled && slot->channel->channel_state < AWS_CHANNEL_SHUT_DOWN) { slot->current_window_update_batch_size = aws_add_size_saturating(slot->current_window_update_batch_size, window); diff --git a/source/darwin/dispatch_queue_event_loop.c b/source/darwin/dispatch_queue_event_loop.c index fde6f5b42..478634e43 100644 --- a/source/darwin/dispatch_queue_event_loop.c +++ b/source/darwin/dispatch_queue_event_loop.c @@ -96,7 +96,6 @@ struct scheduled_service_entry *scheduled_service_entry_new(struct aws_event_loo // may only be called when the dispatch event loop synced data lock is held void scheduled_service_entry_destroy(struct scheduled_service_entry *entry) { - 
// AWS_LOGF_INFO(AWS_LS_IO_EVENT_LOOP, "id=%p: Destroy service entry.", (void *)entry->loop); if (aws_linked_list_node_is_in_list(&entry->node)) { aws_linked_list_remove(&entry->node); } @@ -120,18 +119,14 @@ bool should_schedule_iteration(struct aws_linked_list *scheduled_iterations, uin return entry->timestamp > proposed_iteration_time; } -static void s_finalize(void *context) { - struct aws_event_loop *event_loop = context; - struct dispatch_loop *dispatch_loop = event_loop->impl_data; - AWS_LOGF_INFO(AWS_LS_IO_EVENT_LOOP, "id=%p: Dispatch Queue Finalized", (void *)event_loop); - aws_ref_count_release(&dispatch_loop->ref_count); -} - static void s_dispatch_event_loop_destroy(void *context) { // release dispatch loop + struct aws_event_loop *event_loop = context; struct dispatch_loop *dispatch_loop = event_loop->impl_data; + AWS_LOGF_DEBUG(AWS_LS_IO_EVENT_LOOP, "id=%p: Destroy Dispatch Queue Event Loop.", (void *)event_loop); + aws_mutex_clean_up(&dispatch_loop->synced_data.lock); aws_mem_release(dispatch_loop->allocator, dispatch_loop); aws_event_loop_clean_up_base(event_loop); @@ -149,7 +144,7 @@ struct aws_event_loop *aws_event_loop_new_dispatch_queue_with_options( struct aws_event_loop *loop = aws_mem_calloc(alloc, 1, sizeof(struct aws_event_loop)); - AWS_LOGF_INFO(AWS_LS_IO_EVENT_LOOP, "id=%p: Initializing dispatch_queue event-loop", (void *)loop); + AWS_LOGF_DEBUG(AWS_LS_IO_EVENT_LOOP, "id=%p: Initializing dispatch_queue event-loop", (void *)loop); if (aws_event_loop_init_base(loop, alloc, options->clock)) { goto clean_up_loop; } @@ -184,21 +179,7 @@ struct aws_event_loop *aws_event_loop_new_dispatch_queue_with_options( loop->impl_data = dispatch_loop; loop->vtable = &s_vtable; - /* The following code is an equivalent of the next commented out section. The difference is, async_and_wait - * runs in the callers thread, NOT the event-loop's thread and so we need to use the blocks API. 
- dispatch_async_and_wait(dispatch_loop->dispatch_queue, ^{ - dispatch_loop->running_thread_id = aws_thread_current_thread_id(); - }); */ - // dispatch_block_t block = dispatch_block_create(0, ^{ - // }); - // dispatch_async(dispatch_loop->dispatch_queue, block); - // dispatch_block_wait(block, DISPATCH_TIME_FOREVER); - // Block_release(block); - - dispatch_set_context(dispatch_loop->dispatch_queue, loop); - // Definalizer will be called on dispatch queue ref drop to 0 - dispatch_set_finalizer_f(dispatch_loop->dispatch_queue, &s_finalize); - + // manually increament the thread count, so the library will wait for dispatch queue releasing aws_thread_increment_unjoined_count(); return loop; @@ -218,7 +199,7 @@ struct aws_event_loop *aws_event_loop_new_dispatch_queue_with_options( } static void s_destroy(struct aws_event_loop *event_loop) { - AWS_LOGF_INFO(AWS_LS_IO_EVENT_LOOP, "id=%p: Destroying event_loop", (void *)event_loop); + AWS_LOGF_TRACE(AWS_LS_IO_EVENT_LOOP, "id=%p: Destroying Dispatch Queue Event Loop", (void *)event_loop); struct dispatch_loop *dispatch_loop = event_loop->impl_data; @@ -230,8 +211,6 @@ static void s_destroy(struct aws_event_loop *event_loop) { aws_task_scheduler_clean_up(&dispatch_loop->scheduler); aws_mutex_lock(&dispatch_loop->synced_data.lock); - dispatch_loop->synced_data.suspended = true; - while (!aws_linked_list_empty(&dispatch_loop->synced_data.cross_thread_tasks)) { struct aws_linked_list_node *node = aws_linked_list_pop_front(&dispatch_loop->synced_data.cross_thread_tasks); struct aws_task *task = AWS_CONTAINER_OF(node, struct aws_task, node); @@ -244,7 +223,6 @@ static void s_destroy(struct aws_event_loop *event_loop) { task->fn(task, task->arg, AWS_TASK_STATUS_CANCELED); } - AWS_LOGF_INFO(AWS_LS_IO_EVENT_LOOP, "id=%p: Destroy event loop, clean up service entry.", (void *)event_loop); while (!aws_linked_list_empty(&dispatch_loop->synced_data.scheduling_state.scheduled_services)) { struct aws_linked_list_node *node = 
aws_linked_list_pop_front(&dispatch_loop->synced_data.scheduling_state.scheduled_services); @@ -252,11 +230,15 @@ static void s_destroy(struct aws_event_loop *event_loop) { scheduled_service_entry_destroy(entry); } + dispatch_loop->synced_data.suspended = true; aws_mutex_unlock(&dispatch_loop->synced_data.lock); }); /* we don't want it stopped while shutting down. dispatch_release will fail on a suspended loop. */ dispatch_release(dispatch_loop->dispatch_queue); + + AWS_LOGF_TRACE(AWS_LS_IO_EVENT_LOOP, "id=%p: Releasing Dispatch Queue.", (void *)event_loop); + aws_ref_count_release(&dispatch_loop->ref_count); } static int s_wait_for_stop_completion(struct aws_event_loop *event_loop) { @@ -286,6 +268,8 @@ static int s_stop(struct aws_event_loop *event_loop) { if (!dispatch_loop->synced_data.suspended) { dispatch_loop->synced_data.suspended = true; AWS_LOGF_INFO(AWS_LS_IO_EVENT_LOOP, "id=%p: Stopping event-loop thread.", (void *)event_loop); + // Suspend will increase the dispatch reference count. It is required to call resume before + // releasing the dispatch queue. dispatch_suspend(dispatch_loop->dispatch_queue); } aws_mutex_unlock(&dispatch_loop->synced_data.lock); @@ -314,7 +298,6 @@ bool begin_iteration(struct scheduled_service_entry *entry) { // mark us as running an iteration and remove from the pending list dispatch_loop->synced_data.scheduling_state.is_executing_iteration = true; - // AWS_LOGF_INFO(AWS_LS_IO_EVENT_LOOP, "id=%p: Remove poped service entry node.", (void *)entry->loop); aws_linked_list_remove(&entry->node); should_execute_iteration = true; @@ -342,9 +325,9 @@ void end_iteration(struct scheduled_service_entry *entry) { // no cross thread tasks, so check internal time-based scheduler uint64_t next_task_time = 0; /* we already know it has tasks, we just scheduled one. We just want the next run time. 
*/ - aws_task_scheduler_has_tasks(&loop->scheduler, &next_task_time); + bool has_task = aws_task_scheduler_has_tasks(&loop->scheduler, &next_task_time); - if (next_task_time > 0) { + if (has_task) { // only schedule an iteration if there isn't an existing dispatched iteration for the next task time or // earlier if (should_schedule_iteration(&loop->synced_data.scheduling_state.scheduled_services, next_task_time)) { @@ -353,11 +336,7 @@ void end_iteration(struct scheduled_service_entry *entry) { } } -done: - // AWS_LOGF_INFO( - // AWS_LS_IO_EVENT_LOOP, "id=%p: End of Iteration, start to destroy service entry.", (void *)entry->loop); aws_mutex_unlock(&loop->synced_data.lock); - scheduled_service_entry_destroy(entry); } @@ -375,17 +354,11 @@ void run_iteration(void *context) { aws_event_loop_register_tick_start(event_loop); // run the full iteration here: local cross-thread tasks - AWS_LOGF_TRACE(AWS_LS_IO_EVENT_LOOP, "id=%p: processing cross-thread tasks", (void *)dispatch_loop); while (!aws_linked_list_empty(&dispatch_loop->local_cross_thread_tasks)) { struct aws_linked_list_node *node = aws_linked_list_pop_front(&dispatch_loop->local_cross_thread_tasks); struct aws_task *task = AWS_CONTAINER_OF(node, struct aws_task, node); - AWS_LOGF_TRACE( - AWS_LS_IO_EVENT_LOOP, - "id=%p: task %p pulled to event-loop, scheduling now.", - (void *)dispatch_loop, - (void *)task); /* Timestamp 0 is used to denote "now" tasks */ if (task->timestamp == 0) { aws_task_scheduler_schedule_now(&dispatch_loop->scheduler, task); @@ -397,14 +370,13 @@ void run_iteration(void *context) { // run all scheduled tasks uint64_t now_ns = 0; aws_event_loop_current_clock_time(event_loop, &now_ns); - AWS_LOGF_TRACE(AWS_LS_IO_EVENT_LOOP, "id=%p: running scheduled tasks.", (void *)dispatch_loop); aws_task_scheduler_run_all(&dispatch_loop->scheduler, now_ns); aws_event_loop_register_tick_end(event_loop); end_iteration(entry); } -// checks if a new iteration task needs to be scheduled, given a target 
timestamp +// Checks if a new iteration task needs to be scheduled, given a target timestamp // If so, submits an iteration task to dispatch queue and registers the pending // execution in the event loop's list of scheduled iterations. // The function should be wrapped with dispatch_loop->synced_data->lock @@ -423,24 +395,18 @@ void try_schedule_new_iteration(struct aws_event_loop *loop, uint64_t timestamp) static void s_schedule_task_common(struct aws_event_loop *event_loop, struct aws_task *task, uint64_t run_at_nanos) { struct dispatch_loop *dispatch_loop = event_loop->impl_data; - if (aws_linked_list_node_is_in_list(&task->node)) { - if (run_at_nanos == 0) { - aws_task_scheduler_schedule_now(&dispatch_loop->scheduler, task); - } else { - aws_task_scheduler_schedule_future(&dispatch_loop->scheduler, task, run_at_nanos); - } - return; - } - aws_mutex_lock(&dispatch_loop->synced_data.lock); bool should_schedule = false; bool is_empty = aws_linked_list_empty(&dispatch_loop->synced_data.cross_thread_tasks); + task->timestamp = run_at_nanos; + // We dont have control to dispatch queue thread, threat all tasks are threated as cross thread tasks aws_linked_list_push_back(&dispatch_loop->synced_data.cross_thread_tasks, &task->node); if (is_empty) { if (!dispatch_loop->synced_data.scheduling_state.is_executing_iteration) { - if (should_schedule_iteration(&dispatch_loop->synced_data.scheduling_state.scheduled_services, 0)) { + if (should_schedule_iteration( + &dispatch_loop->synced_data.scheduling_state.scheduled_services, run_at_nanos)) { should_schedule = true; } } @@ -464,10 +430,7 @@ static void s_schedule_task_future(struct aws_event_loop *event_loop, struct aws static void s_cancel_task(struct aws_event_loop *event_loop, struct aws_task *task) { AWS_LOGF_TRACE(AWS_LS_IO_EVENT_LOOP, "id=%p: cancelling task %p", (void *)event_loop, (void *)task); struct dispatch_loop *dispatch_loop = event_loop->impl_data; - - dispatch_async(dispatch_loop->dispatch_queue, ^{ - 
aws_task_scheduler_cancel_task(&dispatch_loop->scheduler, task); - }); + aws_task_scheduler_cancel_task(&dispatch_loop->scheduler, task); } static int s_connect_to_dispatch_queue(struct aws_event_loop *event_loop, struct aws_io_handle *handle) { @@ -494,7 +457,10 @@ static int s_unsubscribe_from_io_events(struct aws_event_loop *event_loop, struc return AWS_OP_SUCCESS; } +// The dispatch queue will assign the task block to threads, we will threat all +// tasks as cross thread tasks. Ignore the caller thread verification for apple +// dispatch queue. static bool s_is_on_callers_thread(struct aws_event_loop *event_loop) { - // DEBUG: for now always return true for caller thread validation + (void)event_loop; return true; -} \ No newline at end of file +} diff --git a/source/darwin/nw_socket.c b/source/darwin/nw_socket.c index 3782abaff..2b5421822 100644 --- a/source/darwin/nw_socket.c +++ b/source/darwin/nw_socket.c @@ -108,18 +108,28 @@ enum socket_state { CLOSED, }; +struct nw_socket_connect_args { + struct aws_task task; + struct aws_allocator *allocator; + struct aws_socket *socket; +}; + struct nw_socket { struct aws_allocator *allocator; struct aws_ref_count ref_count; - nw_connection_t *nw_connection; + nw_connection_t nw_connection; nw_parameters_t socket_options_to_params; struct aws_linked_list read_queue; int last_error; aws_socket_on_readable_fn *on_readable; void *on_readable_user_data; bool setup_run; + bool setup_closing; bool read_queued; bool is_listener; + struct nw_socket_connect_args *connect_args; + aws_socket_on_connection_result_fn *on_connection_result_fn; + void *connect_accept_user_data; }; struct socket_address { @@ -458,6 +468,7 @@ static void s_socket_impl_destroy(void *sock_ptr) { nw_socket->nw_connection = NULL; } + aws_mem_release(nw_socket->allocator, nw_socket->connect_args); aws_mem_release(nw_socket->allocator, nw_socket); nw_socket = NULL; } @@ -497,6 +508,40 @@ static void s_client_clear_dispatch_queue(struct aws_io_handle 
*handle) { nw_connection_set_state_changed_handler(handle->data.handle, NULL); } +static void s_handle_socket_timeout(struct aws_task *task, void *args, aws_task_status status) { + (void)task; + (void)status; + + if (status == AWS_TASK_STATUS_CANCELED) { + // We will clean up the task and args on socket destory. + return; + } + struct nw_socket_connect_args *socket_args = args; + + AWS_LOGF_TRACE(AWS_LS_IO_SOCKET, "task_id=%p: timeout task triggered, evaluating timeouts.", (void *)task); + /* successful connection will have nulled out connect_args->socket */ + if (socket_args->socket) { + AWS_LOGF_ERROR( + AWS_LS_IO_SOCKET, + "id=%p handle=%p: timed out, shutting down.", + (void *)socket_args->socket, + socket_args->socket->io_handle.data.handle); + + socket_args->socket->state = TIMEDOUT; + int error_code = AWS_IO_SOCKET_TIMEOUT; + + // socket_args->socket->event_loop = NULL; + struct nw_socket *socket_impl = socket_args->socket->impl; + + aws_raise_error(error_code); + struct aws_socket *socket = socket_args->socket; + /*socket close sets socket_args->socket to NULL and + * socket_impl->connect_args to NULL. 
*/ + aws_socket_close(socket); + socket_impl->on_connection_result_fn(socket, error_code, socket_impl->connect_accept_user_data); + } +} + static int s_socket_connect_fn( struct aws_socket *socket, const struct aws_socket_endpoint *remote_endpoint, @@ -616,6 +661,27 @@ static int s_socket_connect_fn( aws_event_loop_connect_handle_to_completion_port(event_loop, &socket->io_handle); socket->event_loop = event_loop; + nw_socket->on_connection_result_fn = on_connection_result; + nw_socket->connect_accept_user_data = user_data; + + struct nw_socket *socket_impl = socket->impl; + + nw_socket->connect_args = aws_mem_calloc(socket->allocator, 1, sizeof(struct nw_socket_connect_args)); + if (!nw_socket->connect_args) { + return AWS_OP_ERR; + } + + nw_socket->connect_args->socket = socket; + nw_socket->connect_args->allocator = socket->allocator; + + aws_task_init( + &nw_socket->connect_args->task, + s_handle_socket_timeout, + nw_socket->connect_args, + "NWSocketConnectionTimeoutTask"); + + nw_connection_t handle = socket->io_handle.data.handle; + /* set a handler for socket state changes. This is where we find out if the connection timed out, was successful, * was disconnected etc .... */ nw_connection_set_state_changed_handler( @@ -712,6 +778,28 @@ static int s_socket_connect_fn( nw_connection_start(socket->io_handle.data.handle); nw_retain(socket->io_handle.data.handle); + /* schedule a task to run at the connect timeout interval, if this task runs before the connect + * happens, we consider that a timeout. 
*/ + + uint64_t timeout = 0; + aws_event_loop_current_clock_time(event_loop, &timeout); + AWS_LOGF_DEBUG( + AWS_LS_IO_SOCKET, + "id=%p handle=%p: start connection at %llu.", + (void *)socket, + socket->io_handle.data.handle, + (unsigned long long)timeout); + timeout += + aws_timestamp_convert(socket->options.connect_timeout_ms, AWS_TIMESTAMP_MILLIS, AWS_TIMESTAMP_NANOS, NULL); + AWS_LOGF_DEBUG( + AWS_LS_IO_SOCKET, + "id=%p hanlde=%p: scheduling timeout task for %llu.", + (void *)socket, + socket->io_handle.data.handle, + (unsigned long long)timeout); + nw_socket->connect_args->task.timestamp = timeout; + aws_event_loop_schedule_task_future(event_loop, &nw_socket->connect_args->task, timeout); + return AWS_OP_SUCCESS; } @@ -930,11 +1018,13 @@ static int s_socket_close_fn(struct aws_socket *socket) { if (nw_socket->is_listener) { nw_listener_set_state_changed_handler(socket->io_handle.data.handle, NULL); nw_listener_cancel(socket->io_handle.data.handle); + } else { /* Setting to NULL removes previously set handler from nw_connection_t */ nw_connection_set_state_changed_handler(socket->io_handle.data.handle, NULL); nw_connection_cancel(socket->io_handle.data.handle); } + nw_socket->setup_closing = true; return AWS_OP_SUCCESS; } @@ -1003,6 +1093,18 @@ static void s_schedule_next_read(struct aws_socket *socket) { struct aws_allocator *allocator = socket->allocator; struct aws_linked_list *list = &nw_socket->read_queue; + if (!(socket->state & CONNECTED_READ)) { + AWS_LOGF_ERROR( + AWS_LS_IO_SOCKET, + "id=%p handle=%p: cannot read to because it is not connected", + (void *)socket, + socket->io_handle.data.handle); + return aws_raise_error(AWS_IO_SOCKET_NOT_CONNECTED); + } + + // Acquire nw_socket after we call connection receive, and released it when handler is called. + aws_ref_count_acquire(&nw_socket->ref_count); + /* read and let me know when you've done it. 
*/ nw_connection_receive( socket->io_handle.data.handle, @@ -1013,7 +1115,11 @@ static void s_schedule_next_read(struct aws_socket *socket) { AWS_LOGF_TRACE( AWS_LS_IO_SOCKET, "id=%p handle=%p: read cb invoked", (void *)socket, socket->io_handle.data.handle); - if (!error || nw_error_get_error_code(error) == 0) { + if (nw_socket->setup_closing) { + AWS_LOGF_TRACE( + AWS_LS_IO_SOCKET, "id=%p handle=%p: socket closed", (void *)socket, socket->io_handle.data.handle); + aws_raise_error(AWS_IO_SOCKET_CLOSED); + } else if (!error || nw_error_get_error_code(error) == 0) { if (data) { struct read_queue_node *node = aws_mem_calloc(allocator, 1, sizeof(struct read_queue_node)); node->allocator = allocator; @@ -1044,10 +1150,7 @@ static void s_schedule_next_read(struct aws_socket *socket) { nw_socket->on_readable(socket, error_code, nw_socket->on_readable_user_data); } - // DEBUG WIP these may or may not be necessary. release on error seems okay but - // release on context or data here appears to double release. 
- // nw_release(context); - nw_release(error); + aws_ref_count_release(&nw_socket->ref_count); }); } @@ -1168,31 +1271,50 @@ static int s_socket_write_fn( return aws_raise_error(AWS_IO_SOCKET_NOT_CONNECTED); } + struct nw_socket *nw_socket = socket->impl; + aws_ref_count_acquire(&nw_socket->ref_count); + nw_connection_t handle = socket->io_handle.data.handle; + AWS_ASSERT(written_fn); dispatch_data_t data = dispatch_data_create(cursor->ptr, cursor->len, NULL, DISPATCH_DATA_DESTRUCTOR_FREE); printf("\nWriting to SOCKET\n\n"); nw_connection_send( - socket->io_handle.data.handle, data, _nw_content_context_default_message, true, ^(nw_error_t error) { + handle, data, _nw_content_context_default_message, true, ^(nw_error_t error) { AWS_LOGF_TRACE( AWS_LS_IO_SOCKET, "id=%p handle=%p: processing write requests, called from aws_socket_write", (void *)socket, - socket->io_handle.data.handle); + handle); + + if (nw_socket->setup_closing) { + AWS_LOGF_TRACE( + AWS_LS_IO_SOCKET, + "id=%p handle=%p: socket closed", + (void *)socket, + handle); + written_fn(socket, 0, 0, user_data); + goto nw_socket_release; + } + + AWS_LOGF_ERROR( + AWS_LS_IO_SOCKET, + "id=%p handle=%p: DEBUG:: callback writing message: %p", + (void *)socket, + handle, user_data); int error_code = !error || nw_error_get_error_code(error) == 0 ? AWS_OP_SUCCESS : s_determine_socket_error(nw_error_get_error_code(error)); if (error_code) { - struct nw_socket *nw_socket = socket->impl; nw_socket->last_error = error_code; aws_raise_error(error_code); AWS_LOGF_ERROR( AWS_LS_IO_SOCKET, "id=%p handle=%p: error during write %d", (void *)socket, - socket->io_handle.data.handle, + handle, error_code); } @@ -1201,9 +1323,11 @@ static int s_socket_write_fn( AWS_LS_IO_SOCKET, "id=%p handle=%p: send written size %d", (void *)socket, - socket->io_handle.data.handle, + handle, (int)written_size); written_fn(socket, error_code, !error_code ? 
written_size : 0, user_data); +nw_socket_release: + aws_ref_count_release(&nw_socket->ref_count); }); return AWS_OP_SUCCESS; diff --git a/source/darwin/secure_transport_tls_channel_handler.c b/source/darwin/secure_transport_tls_channel_handler.c index 3267e0989..7898f40b7 100644 --- a/source/darwin/secure_transport_tls_channel_handler.c +++ b/source/darwin/secure_transport_tls_channel_handler.c @@ -113,6 +113,8 @@ struct secure_transport_handler { bool negotiation_finished; bool verify_peer; bool read_task_pending; + enum aws_tls_handler_read_state read_state; + int delay_shutdown_error_code; }; static OSStatus s_read_cb(SSLConnectionRef conn, void *data, size_t *len) { @@ -550,6 +552,41 @@ static int s_process_write_message( return AWS_OP_SUCCESS; } +static void s_run_read(struct aws_channel_task *task, void *arg, aws_task_status status); + +static void s_initialize_read_delay_shutdown( + struct aws_channel_handler *handler, + struct aws_channel_slot *slot, + int error_code) { + struct secure_transport_handler *secure_transport_handler = handler->impl; + /** + * In case of if we have any queued data in the handler after negotiation and we start to shutdown, + * make sure we pass those data down the pipeline before we complete the shutdown. + */ + AWS_LOGF_DEBUG( + AWS_LS_IO_TLS, + "id=%p: TLS handler still have pending data to be delivered during shutdown. Wait until downstream " + "reads the data.", + (void *)handler); + if (aws_channel_slot_downstream_read_window(slot) == 0) { + AWS_LOGF_WARN( + AWS_LS_IO_TLS, + "id=%p: TLS shutdown delayed. Pending data cannot be processed until the flow-control window opens. " + " Your application may hang if the read window never opens", + (void *)handler); + } + secure_transport_handler->read_state = AWS_TLS_HANDLER_READ_SHUTTING_DOWN; + secure_transport_handler->delay_shutdown_error_code = error_code; + if (!secure_transport_handler->read_task_pending) { + /* Kick off read, in case data arrives with TLS negotiation. 
Shutdown starts right after negotiation. + * Nothing will kick off read in that case. */ + secure_transport_handler->read_task_pending = true; + aws_channel_task_init( + &secure_transport_handler->read_task, s_run_read, handler, "darwin_channel_handler_read_on_delay_shutdown"); + aws_channel_schedule_task_now(slot->channel, &secure_transport_handler->read_task); + } +} + static int s_handle_shutdown( struct aws_channel_handler *handler, struct aws_channel_slot *slot, @@ -558,24 +595,30 @@ static int s_handle_shutdown( bool abort_immediately) { struct secure_transport_handler *secure_transport_handler = handler->impl; - if (dir == AWS_CHANNEL_DIR_WRITE) { + if (dir == AWS_CHANNEL_DIR_READ) { + AWS_LOGF_DEBUG( + AWS_LS_IO_TLS, "id=%p: shutting down read direction with error %d.", (void *)handler, error_code); + if (!abort_immediately && secure_transport_handler->negotiation_finished && + !aws_linked_list_empty(&secure_transport_handler->input_queue) && slot->adj_right) { + s_initialize_read_delay_shutdown(handler, slot, error_code); + /* Early out, not complete the shutdown process for the handler until the handler processes the pending + * data. */ + return AWS_OP_SUCCESS; + } + secure_transport_handler->read_state = AWS_TLS_HANDLER_READ_SHUT_DOWN_COMPLETE; + } else { + /* Shutdown in write direction */ if (!abort_immediately && error_code != AWS_IO_SOCKET_CLOSED) { AWS_LOGF_TRACE(AWS_LS_IO_TLS, "id=%p: shutting down write direction.", (void *)handler); SSLClose(secure_transport_handler->ctx); } - } else { - AWS_LOGF_DEBUG( - AWS_LS_IO_TLS, - "id=%p: shutting down read direction with error %d. 
Flushing queues.", - (void *)handler, - error_code); - while (!aws_linked_list_empty(&secure_transport_handler->input_queue)) { - struct aws_linked_list_node *node = aws_linked_list_pop_front(&secure_transport_handler->input_queue); - struct aws_io_message *message = AWS_CONTAINER_OF(node, struct aws_io_message, queueing_handle); - aws_mem_release(message->allocator, message); - } } - + /* Flushing queues */ + while (!aws_linked_list_empty(&secure_transport_handler->input_queue)) { + struct aws_linked_list_node *node = aws_linked_list_pop_front(&secure_transport_handler->input_queue); + struct aws_io_message *message = AWS_CONTAINER_OF(node, struct aws_io_message, queueing_handle); + aws_mem_release(message->allocator, message); + } return aws_channel_slot_on_handler_shutdown_complete(slot, dir, error_code, abort_immediately); } @@ -585,6 +628,12 @@ static int s_process_read_message( struct aws_io_message *message) { struct secure_transport_handler *secure_transport_handler = handler->impl; + if (secure_transport_handler->read_state == AWS_TLS_HANDLER_READ_SHUT_DOWN_COMPLETE) { + if (message) { + aws_mem_release(message->allocator, message); + } + return AWS_OP_SUCCESS; + } if (message) { aws_linked_list_push_back(&secure_transport_handler->input_queue, &message->queueing_handle); @@ -610,59 +659,67 @@ static int s_process_read_message( AWS_LS_IO_TLS, "id=%p: downstream window is %llu", (void *)handler, (unsigned long long)downstream_window); size_t processed = 0; - OSStatus status = noErr; - while (processed < downstream_window && status == noErr) { + int shutdown_error_code = 0; + while (processed < downstream_window) { struct aws_io_message *outgoing_read_message = aws_channel_acquire_message_from_pool( slot->channel, AWS_IO_MESSAGE_APPLICATION_DATA, downstream_window - processed); size_t read = 0; - status = SSLRead( + OSStatus status = SSLRead( secure_transport_handler->ctx, outgoing_read_message->message_data.buffer, 
outgoing_read_message->message_data.capacity, &read); AWS_LOGF_TRACE(AWS_LS_IO_TLS, "id=%p: bytes read %llu", (void *)handler, (unsigned long long)read); - if (read <= 0) { - aws_mem_release(outgoing_read_message->allocator, outgoing_read_message); - - if (status != errSSLWouldBlock) { - AWS_LOGF_ERROR( - AWS_LS_IO_TLS, - "id=%p: error reported during SSLRead. OSStatus code %d", - (void *)handler, - (int)status); + if (read > 0) { + processed += read; + outgoing_read_message->message_data.len = read; - if (status != errSSLClosedGraceful) { - aws_raise_error(AWS_IO_TLS_ERROR_READ_FAILURE); - aws_channel_shutdown(secure_transport_handler->parent_slot->channel, AWS_IO_TLS_ERROR_READ_FAILURE); - } else { - AWS_LOGF_TRACE(AWS_LS_IO_TLS, "id=%p: connection shutting down gracefully.", (void *)handler); - aws_channel_shutdown(secure_transport_handler->parent_slot->channel, AWS_ERROR_SUCCESS); - } + if (secure_transport_handler->on_data_read) { + secure_transport_handler->on_data_read( + handler, slot, &outgoing_read_message->message_data, secure_transport_handler->user_data); } - continue; - }; - - processed += read; - outgoing_read_message->message_data.len = read; - if (secure_transport_handler->on_data_read) { - secure_transport_handler->on_data_read( - handler, slot, &outgoing_read_message->message_data, secure_transport_handler->user_data); - } - - if (slot->adj_right) { - if (aws_channel_slot_send_message(slot, outgoing_read_message, AWS_CHANNEL_DIR_READ)) { + if (slot->adj_right) { + if (aws_channel_slot_send_message(slot, outgoing_read_message, AWS_CHANNEL_DIR_READ)) { + aws_mem_release(outgoing_read_message->allocator, outgoing_read_message); + shutdown_error_code = aws_last_error(); + goto shutdown_channel; + } + /* outgoing message was pushed to the input_queue, so this handler owns it now */ + } else { aws_mem_release(outgoing_read_message->allocator, outgoing_read_message); - aws_channel_shutdown(secure_transport_handler->parent_slot->channel, 
aws_last_error()); - /* incoming message was pushed to the input_queue, so this handler owns it now */ - return AWS_OP_SUCCESS; } } else { + /* Nothing was read */ aws_mem_release(outgoing_read_message->allocator, outgoing_read_message); } + + switch (status) { + case errSSLWouldBlock: + if (secure_transport_handler->read_state == AWS_TLS_HANDLER_READ_SHUTTING_DOWN) { + /* Propagate the shutdown as we blocked now. */ + goto shutdown_channel; + } else { + break; + } + case errSSLClosedGraceful: + AWS_LOGF_TRACE(AWS_LS_IO_TLS, "id=%p: connection shutting down gracefully.", (void *)handler); + goto shutdown_channel; + case noErr: + /* continue the while loop */ + continue; + default: + /* unexpected error happened */ + aws_raise_error(AWS_IO_TLS_ERROR_READ_FAILURE); + shutdown_error_code = AWS_IO_TLS_ERROR_READ_FAILURE; + goto shutdown_channel; + } + + /* Break the while loop */ + break; } AWS_LOGF_TRACE( AWS_LS_IO_TLS, @@ -671,6 +728,21 @@ static int s_process_read_message( (unsigned long long)downstream_window - processed); return AWS_OP_SUCCESS; + +shutdown_channel: + if (secure_transport_handler->read_state == AWS_TLS_HANDLER_READ_SHUTTING_DOWN) { + if (secure_transport_handler->delay_shutdown_error_code != 0) { + /* Propagate the original error code if it is set. */ + shutdown_error_code = secure_transport_handler->delay_shutdown_error_code; + } + /* Continue the shutdown process delayed before. 
*/ + secure_transport_handler->read_state = AWS_TLS_HANDLER_READ_SHUT_DOWN_COMPLETE; + aws_channel_slot_on_handler_shutdown_complete(slot, AWS_CHANNEL_DIR_READ, shutdown_error_code, false); + } else { + /* Starts the shutdown process */ + aws_channel_shutdown(slot->channel, shutdown_error_code); + } + return AWS_OP_SUCCESS; } static void s_run_read(struct aws_channel_task *task, void *arg, aws_task_status status) { @@ -685,6 +757,9 @@ static void s_run_read(struct aws_channel_task *task, void *arg, aws_task_status static int s_increment_read_window(struct aws_channel_handler *handler, struct aws_channel_slot *slot, size_t size) { struct secure_transport_handler *secure_transport_handler = handler->impl; + if (secure_transport_handler->read_state == AWS_TLS_HANDLER_READ_SHUT_DOWN_COMPLETE) { + return AWS_OP_SUCCESS; + } AWS_LOGF_TRACE( AWS_LS_IO_TLS, "id=%p: increment read window message received %llu", (void *)handler, (unsigned long long)size); @@ -706,13 +781,9 @@ static int s_increment_read_window(struct aws_channel_handler *handler, struct a aws_channel_slot_increment_read_window(slot, window_update_size); } - if (secure_transport_handler->negotiation_finished && !secure_transport_handler->read_task.node.next) { + if (secure_transport_handler->negotiation_finished && !secure_transport_handler->read_task_pending) { /* TLS requires full records before it can decrypt anything. As a result we need to check everything we've * buffered instead of just waiting on a read from the socket, or we'll hit a deadlock. - * - * We have messages in a queue and they need to be run after the socket has popped (even if it didn't have data - * to read). Alternatively, s2n reads entire records at a time, so we'll need to grab whatever we can and we - * have no idea what's going on inside there. So we need to attempt another read. 
*/ secure_transport_handler->read_task_pending = true; aws_channel_task_init( diff --git a/source/event_loop.c b/source/event_loop.c index dd0768eb1..f3a7197db 100644 --- a/source/event_loop.c +++ b/source/event_loop.c @@ -24,27 +24,21 @@ static const struct aws_event_loop_configuration s_available_configurations[] = .style = AWS_EVENT_LOOP_STYLE_COMPLETION_PORT_BASED, }, #endif -#if AWS_USE_KQUEUE - { - .name = "BSD Edge-Triggered KQueue", - .event_loop_new_fn = aws_event_loop_new_kqueue_with_options, - .style = AWS_EVENT_LOOP_STYLE_POLL_BASED, - .is_default = true, - }, -#endif -#if TARGET_OS_MAC +#if AWS_USE_DISPATCH_QUEUE /* use kqueue on OSX and dispatch_queues everywhere else */ { .name = "Apple Dispatch Queue", .event_loop_new_fn = aws_event_loop_new_dispatch_queue_with_options, .style = AWS_EVENT_LOOP_STYLE_COMPLETION_PORT_BASED, -# if TARGET_OS_OSX - /* DEBUG WIP temp set the dispatch queue to be default. */ .is_default = true, - // .is_default = false, -# else + }, +#endif +#if AWS_USE_KQUEUE + { + .name = "BSD Edge-Triggered KQueue", + .event_loop_new_fn = aws_event_loop_new_kqueue_with_options, + .style = AWS_EVENT_LOOP_STYLE_POLL_BASED, .is_default = true, -# endif }, #endif #if AWS_USE_EPOLL @@ -488,10 +482,10 @@ size_t aws_event_loop_get_load_factor(struct aws_event_loop *event_loop) { return aws_atomic_load_int(&event_loop->current_load_factor); } -// DEBUG: TODO: WORKAROUND THE CALLER THREAD VALIDATION ON DISPATCH QUEUE. +// As dispatch queue has ARC support, we could directly release the dispatch queue event loop. Disable the +// caller thread validation on dispatch queue. 
#ifndef AWS_USE_DISPATCH_QUEUE -# define AWS_EVENT_LOOP_NOT_CALLER_THREAD(eventloop) -AWS_ASSERT(!aws_event_loop_thread_is_callers_thread(eventloop)); +# define AWS_EVENT_LOOP_NOT_CALLER_THREAD(eventloop) AWS_ASSERT(!aws_event_loop_thread_is_callers_thread(eventloop)); #else # define AWS_EVENT_LOOP_NOT_CALLER_THREAD(eventloop) #endif @@ -502,7 +496,6 @@ void aws_event_loop_destroy(struct aws_event_loop *event_loop) { } AWS_ASSERT(event_loop->vtable && event_loop->vtable->destroy); - // DEBUG: TODO: WORKAROUND THE CALLER THREAD VALIDATION ON DISPATCH QUEUE. AWS_EVENT_LOOP_NOT_CALLER_THREAD(event_loop); event_loop->vtable->destroy(event_loop); diff --git a/source/s2n/s2n_tls_channel_handler.c b/source/s2n/s2n_tls_channel_handler.c index d72449eab..355a64b1b 100644 --- a/source/s2n/s2n_tls_channel_handler.c +++ b/source/s2n/s2n_tls_channel_handler.c @@ -35,12 +35,6 @@ static const char *s_default_ca_dir = NULL; static const char *s_default_ca_file = NULL; -struct s2n_delayed_shutdown_task { - struct aws_channel_task task; - struct aws_channel_slot *slot; - int error; -}; - struct s2n_handler { struct aws_channel_handler handler; struct aws_tls_channel_handler_shared shared_state; @@ -63,7 +57,11 @@ struct s2n_handler { NEGOTIATION_FAILED, NEGOTIATION_SUCCEEDED, } state; - struct s2n_delayed_shutdown_task delayed_shutdown_task; + struct aws_channel_task read_task; + bool read_task_pending; + enum aws_tls_handler_read_state read_state; + int shutdown_error_code; + struct aws_channel_task delayed_shutdown_task; }; struct s2n_ctx { @@ -523,6 +521,13 @@ static int s_s2n_handler_process_read_message( struct s2n_handler *s2n_handler = handler->impl; + if (s2n_handler->read_state == AWS_TLS_HANDLER_READ_SHUT_DOWN_COMPLETE) { + if (message) { + aws_mem_release(message->allocator, message); + } + return AWS_OP_SUCCESS; + } + if (AWS_UNLIKELY(s2n_handler->state == NEGOTIATION_FAILED)) { return aws_raise_error(AWS_IO_TLS_ERROR_NEGOTIATION_FAILURE); } @@ -532,7 +537,7 @@ static 
int s_s2n_handler_process_read_message( if (s2n_handler->state == NEGOTIATION_ONGOING) { size_t message_len = message->message_data.len; - if (!s_drive_negotiation(handler)) { + if (s_drive_negotiation(handler) == AWS_OP_SUCCESS) { aws_channel_slot_increment_read_window(slot, message_len); } else { aws_channel_shutdown(s2n_handler->slot->channel, AWS_IO_TLS_ERROR_NEGOTIATION_FAILURE); @@ -546,6 +551,7 @@ static int s_s2n_handler_process_read_message( if (slot->adj_right) { downstream_window = aws_channel_slot_downstream_read_window(slot); } + int shutdown_error_code = 0; size_t processed = 0; AWS_LOGF_TRACE( @@ -577,8 +583,7 @@ static int s_s2n_handler_process_read_message( (void *)handler, s2n_connection_get_alert(s2n_handler->connection)); aws_mem_release(outgoing_read_message->allocator, outgoing_read_message); - aws_channel_shutdown(slot->channel, AWS_OP_SUCCESS); - return AWS_OP_SUCCESS; + goto shutdown_channel; } if (read < 0) { @@ -586,6 +591,10 @@ static int s_s2n_handler_process_read_message( /* the socket blocked so exit from the loop */ if (s2n_error_get_type(s2n_errno) == S2N_ERR_T_BLOCKED) { + if (s2n_handler->read_state == AWS_TLS_HANDLER_READ_SHUTTING_DOWN) { + /* Propagate the shutdown as we blocked now. */ + goto shutdown_channel; + } break; } @@ -596,8 +605,8 @@ static int s_s2n_handler_process_read_message( (void *)handler, s2n_strerror(s2n_errno, "EN"), s2n_strerror_debug(s2n_errno, "EN")); - aws_channel_shutdown(slot->channel, AWS_IO_TLS_ERROR_READ_FAILURE); - return AWS_OP_SUCCESS; + shutdown_error_code = AWS_IO_TLS_ERROR_READ_FAILURE; + goto shutdown_channel; }; /* if read > 0 */ @@ -622,6 +631,20 @@ static int s_s2n_handler_process_read_message( (unsigned long long)downstream_window - processed); return AWS_OP_SUCCESS; + +shutdown_channel: + if (s2n_handler->read_state == AWS_TLS_HANDLER_READ_SHUTTING_DOWN) { + if (s2n_handler->shutdown_error_code != 0) { + /* Propagate the original error code if it is set. 
*/ + shutdown_error_code = s2n_handler->shutdown_error_code; + } + s2n_handler->read_state = AWS_TLS_HANDLER_READ_SHUT_DOWN_COMPLETE; + aws_channel_slot_on_handler_shutdown_complete(slot, AWS_CHANNEL_DIR_READ, shutdown_error_code, false); + } else { + /* Starts the shutdown process */ + aws_channel_shutdown(slot->channel, shutdown_error_code); + } + return AWS_OP_SUCCESS; } static int s_s2n_handler_process_write_message( @@ -668,10 +691,7 @@ static void s_delayed_shutdown_task_fn(struct aws_channel_task *channel_task, vo s2n_shutdown(s2n_handler->connection, &blocked); } aws_channel_slot_on_handler_shutdown_complete( - s2n_handler->delayed_shutdown_task.slot, - AWS_CHANNEL_DIR_WRITE, - s2n_handler->delayed_shutdown_task.error, - false); + s2n_handler->slot, AWS_CHANNEL_DIR_WRITE, s2n_handler->shutdown_error_code, false); } static enum aws_tls_signature_algorithm s_s2n_to_aws_signature_algorithm(s2n_tls_signature_algorithm s2n_alg) { @@ -979,8 +999,7 @@ static int s_s2n_do_delayed_shutdown( int error_code) { struct s2n_handler *s2n_handler = (struct s2n_handler *)handler->impl; - s2n_handler->delayed_shutdown_task.slot = slot; - s2n_handler->delayed_shutdown_task.error = error_code; + s2n_handler->shutdown_error_code = error_code; uint64_t shutdown_delay = s2n_connection_get_delay(s2n_handler->connection); uint64_t now = 0; @@ -990,11 +1009,56 @@ static int s_s2n_do_delayed_shutdown( } uint64_t shutdown_time = aws_add_u64_saturating(shutdown_delay, now); - aws_channel_schedule_task_future(slot->channel, &s2n_handler->delayed_shutdown_task.task, shutdown_time); + aws_channel_schedule_task_future(slot->channel, &s2n_handler->delayed_shutdown_task, shutdown_time); return AWS_OP_SUCCESS; } +static void s_run_read(struct aws_channel_task *task, void *arg, aws_task_status status) { + task->task_fn = NULL; + task->arg = NULL; + + if (status == AWS_TASK_STATUS_RUN_READY) { + struct aws_channel_handler *handler = (struct aws_channel_handler *)arg; + struct s2n_handler 
*s2n_handler = (struct s2n_handler *)handler->impl; + s2n_handler->read_task_pending = false; + s_s2n_handler_process_read_message(handler, s2n_handler->slot, NULL); + } +} + +static void s_initialize_read_delay_shutdown( + struct aws_channel_handler *handler, + struct aws_channel_slot *slot, + int error_code) { + struct s2n_handler *s2n_handler = (struct s2n_handler *)handler->impl; + /** + * In case of if we have any queued data in the handler after negotiation and we start to shutdown, + * make sure we pass those data down the pipeline before we complete the shutdown. + */ + AWS_LOGF_DEBUG( + AWS_LS_IO_TLS, + "id=%p: TLS handler still have pending data to be delivered during shutdown. Wait until downstream " + "reads the data.", + (void *)handler); + if (aws_channel_slot_downstream_read_window(slot) == 0) { + AWS_LOGF_WARN( + AWS_LS_IO_TLS, + "id=%p: TLS shutdown delayed. Pending data cannot be processed until the flow-control window opens. " + " Your application may hang if the read window never opens", + (void *)handler); + } + s2n_handler->read_state = AWS_TLS_HANDLER_READ_SHUTTING_DOWN; + s2n_handler->shutdown_error_code = error_code; + if (!s2n_handler->read_task_pending) { + /* Kick off read, in case data arrives with TLS negotiation. Shutdown starts right after negotiation. + * Nothing will kick off read in that case. 
*/ + s2n_handler->read_task_pending = true; + aws_channel_task_init( + &s2n_handler->read_task, s_run_read, handler, "s2n_channel_handler_read_on_delay_shutdown"); + aws_channel_schedule_task_now(slot->channel, &s2n_handler->read_task); + } +} + static int s_s2n_handler_shutdown( struct aws_channel_handler *handler, struct aws_channel_slot *slot, @@ -1003,14 +1067,7 @@ static int s_s2n_handler_shutdown( bool abort_immediately) { struct s2n_handler *s2n_handler = (struct s2n_handler *)handler->impl; - if (dir == AWS_CHANNEL_DIR_WRITE) { - if (!abort_immediately && error_code != AWS_IO_SOCKET_CLOSED) { - AWS_LOGF_DEBUG(AWS_LS_IO_TLS, "id=%p: Scheduling delayed write direction shutdown", (void *)handler); - if (s_s2n_do_delayed_shutdown(handler, slot, error_code) == AWS_OP_SUCCESS) { - return AWS_OP_SUCCESS; - } - } - } else { + if (dir == AWS_CHANNEL_DIR_READ) { AWS_LOGF_DEBUG( AWS_LS_IO_TLS, "id=%p: Shutting down read direction with error code %d", (void *)handler, error_code); @@ -1019,33 +1076,39 @@ static int s_s2n_handler_shutdown( s2n_handler->state = NEGOTIATION_FAILED; } - while (!aws_linked_list_empty(&s2n_handler->input_queue)) { - struct aws_linked_list_node *node = aws_linked_list_pop_front(&s2n_handler->input_queue); - struct aws_io_message *message = AWS_CONTAINER_OF(node, struct aws_io_message, queueing_handle); - aws_mem_release(message->allocator, message); + if (!abort_immediately && s2n_handler->state == NEGOTIATION_SUCCEEDED && + !aws_linked_list_empty(&s2n_handler->input_queue) && slot->adj_right) { + s_initialize_read_delay_shutdown(handler, slot, error_code); + return AWS_OP_SUCCESS; + } + s2n_handler->read_state = AWS_TLS_HANDLER_READ_SHUT_DOWN_COMPLETE; + } else { + /* Shutdown in write direction */ + if (!abort_immediately && error_code != AWS_IO_SOCKET_CLOSED) { + AWS_LOGF_DEBUG(AWS_LS_IO_TLS, "id=%p: Scheduling delayed write direction shutdown", (void *)handler); + if (s_s2n_do_delayed_shutdown(handler, slot, error_code) == AWS_OP_SUCCESS) 
{ + return AWS_OP_SUCCESS; + } } } + while (!aws_linked_list_empty(&s2n_handler->input_queue)) { + struct aws_linked_list_node *node = aws_linked_list_pop_front(&s2n_handler->input_queue); + struct aws_io_message *message = AWS_CONTAINER_OF(node, struct aws_io_message, queueing_handle); + aws_mem_release(message->allocator, message); + } return aws_channel_slot_on_handler_shutdown_complete(slot, dir, error_code, abort_immediately); } -static void s_run_read(struct aws_channel_task *task, void *arg, aws_task_status status) { - task->task_fn = NULL; - task->arg = NULL; - - if (status == AWS_TASK_STATUS_RUN_READY) { - struct aws_channel_handler *handler = (struct aws_channel_handler *)arg; - struct s2n_handler *s2n_handler = (struct s2n_handler *)handler->impl; - s_s2n_handler_process_read_message(handler, s2n_handler->slot, NULL); - } -} - static int s_s2n_handler_increment_read_window( struct aws_channel_handler *handler, struct aws_channel_slot *slot, size_t size) { (void)size; struct s2n_handler *s2n_handler = handler->impl; + if (s2n_handler->read_state == AWS_TLS_HANDLER_READ_SHUT_DOWN_COMPLETE) { + return AWS_OP_SUCCESS; + } size_t downstream_size = aws_channel_slot_downstream_read_window(slot); size_t current_window_size = slot->window_size; @@ -1067,16 +1130,17 @@ static int s_s2n_handler_increment_read_window( aws_channel_slot_increment_read_window(slot, window_update_size); } - if (s2n_handler->state == NEGOTIATION_SUCCEEDED && !s2n_handler->sequential_tasks.node.next) { + if (s2n_handler->state == NEGOTIATION_SUCCEEDED && !s2n_handler->read_task_pending) { /* TLS requires full records before it can decrypt anything. As a result we need to check everything we've * buffered instead of just waiting on a read from the socket, or we'll hit a deadlock. * - * We have messages in a queue and they need to be run after the socket has popped (even if it didn't have data - * to read). 
Alternatively, s2n reads entire records at a time, so we'll need to grab whatever we can and we - * have no idea what's going on inside there. So we need to attempt another read.*/ + * We have messages in a queue and they need to be run after the socket has popped (even if it didn't have + * data to read). Alternatively, s2n reads entire records at a time, so we'll need to grab whatever we can + * and we have no idea what's going on inside there. So we need to attempt another read.*/ + s2n_handler->read_task_pending = true; aws_channel_task_init( - &s2n_handler->sequential_tasks, s_run_read, handler, "s2n_channel_handler_read_on_window_increment"); - aws_channel_schedule_task_now(slot->channel, &s2n_handler->sequential_tasks); + &s2n_handler->read_task, s_run_read, handler, "s2n_channel_handler_read_on_window_increment"); + aws_channel_schedule_task_now(slot->channel, &s2n_handler->read_task); } return AWS_OP_SUCCESS; @@ -1302,10 +1366,7 @@ static struct aws_channel_handler *s_new_tls_handler( } aws_channel_task_init( - &s2n_handler->delayed_shutdown_task.task, - s_delayed_shutdown_task_fn, - &s2n_handler->handler, - "s2n_delayed_shutdown"); + &s2n_handler->delayed_shutdown_task, s_delayed_shutdown_task_fn, &s2n_handler->handler, "s2n_delayed_shutdown"); if (s_s2n_tls_channel_handler_schedule_thread_local_cleanup(slot)) { goto cleanup_conn; @@ -1464,7 +1525,8 @@ static struct aws_tls_ctx *s_tls_ctx_new( switch (options->cipher_pref) { case AWS_IO_TLS_CIPHER_PREF_SYSTEM_DEFAULT: - /* No-Op, if the user configured a minimum_tls_version then a version-specific Cipher Preference was set */ + /* No-Op, if the user configured a minimum_tls_version then a version-specific Cipher Preference was set + */ break; case AWS_IO_TLS_CIPHER_PREF_PQ_TLSv1_0_2021_05: security_policy = "PQ-TLS-1-0-2021-05-26"; diff --git a/source/socket_channel_handler.c b/source/socket_channel_handler.c index 27f788a28..e8c9c5499 100644 --- a/source/socket_channel_handler.c +++ 
b/source/socket_channel_handler.c @@ -208,8 +208,7 @@ static void s_do_read(struct socket_handler *socket_handler) { } } -/* the socket is either readable or errored out. If it's readable, kick off s_do_read() to do its thing. - * If an error, start the channel shutdown process. */ +/* the socket is either readable or errored out. If it's readable, kick off s_do_read() to do its thing. */ static void s_on_readable_notification(struct aws_socket *socket, int error_code, void *user_data) { (void)socket; diff --git a/source/windows/iocp/pipe.c b/source/windows/iocp/pipe.c index 04145c679..a534c7e20 100644 --- a/source/windows/iocp/pipe.c +++ b/source/windows/iocp/pipe.c @@ -251,7 +251,7 @@ int aws_pipe_init( } } - int err = aws_event_loop_connect_handle_to_io_completion_port(write_end_event_loop, &write_impl->handle); + int err = aws_event_loop_connect_handle_to_completion_port(write_end_event_loop, &write_impl->handle); if (err) { goto clean_up; } @@ -282,7 +282,7 @@ int aws_pipe_init( goto clean_up; } - err = aws_event_loop_connect_handle_to_io_completion_port(read_end_event_loop, &read_impl->handle); + err = aws_event_loop_connect_handle_to_completion_port(read_end_event_loop, &read_impl->handle); if (err) { goto clean_up; } diff --git a/source/windows/iocp/socket.c b/source/windows/iocp/socket.c index 0378a183e..6039abb0c 100644 --- a/source/windows/iocp/socket.c +++ b/source/windows/iocp/socket.c @@ -2555,7 +2555,7 @@ int aws_socket_assign_to_event_loop(struct aws_socket *socket, struct aws_event_ } socket->event_loop = event_loop; - return aws_event_loop_connect_handle_to_io_completion_port(event_loop, &socket->io_handle); + return aws_event_loop_connect_handle_to_completion_port(event_loop, &socket->io_handle); } struct aws_event_loop *aws_socket_get_event_loop(struct aws_socket *socket) { diff --git a/source/windows/secure_channel_tls_handler.c b/source/windows/secure_channel_tls_handler.c index b3b42490e..3b0419919 100644 --- 
a/source/windows/secure_channel_tls_handler.c +++ b/source/windows/secure_channel_tls_handler.c @@ -53,7 +53,7 @@ struct secure_channel_ctx { struct aws_tls_ctx ctx; struct aws_string *alpn_list; SCHANNEL_CRED credentials; - PCERT_CONTEXT pcerts; + PCCERT_CONTEXT pcerts; HCERTSTORE cert_store; HCERTSTORE custom_trust_store; HCRYPTPROV crypto_provider; @@ -103,6 +103,10 @@ struct secure_channel_handler { bool advertise_alpn_message; bool negotiation_finished; bool verify_peer; + struct aws_channel_task read_task; + bool read_task_pending; + enum aws_tls_handler_read_state read_state; + int shutdown_error_code; }; static size_t s_message_overhead(struct aws_channel_handler *handler) { @@ -188,7 +192,7 @@ static int s_manually_verify_peer_cert(struct aws_channel_handler *handler) { int result = AWS_OP_ERR; CERT_CONTEXT *peer_certificate = NULL; HCERTCHAINENGINE engine = NULL; - CERT_CHAIN_CONTEXT *cert_chain_ctx = NULL; + PCCERT_CHAIN_CONTEXT cert_chain_ctx = NULL; /* get the peer's certificate so we can validate it.*/ SECURITY_STATUS status = @@ -1055,7 +1059,7 @@ static int s_do_client_side_negotiation_step_2(struct aws_channel_handler *handl static int s_do_application_data_decrypt(struct aws_channel_handler *handler) { struct secure_channel_handler *sc_handler = handler->impl; - /* I know this is an unncessary initialization, it's initialized here to make linters happy.*/ + /* I know this is an unnecessary initialization, it's initialized here to make linters happy.*/ int error = AWS_OP_ERR; /* when we get an Extra buffer we have to move the pointer and replay the buffer, so we loop until we don't have any extra buffers left over, in the last phase, we then go ahead and send the output. 
This state function will @@ -1098,8 +1102,7 @@ static int s_do_application_data_decrypt(struct aws_channel_handler *handler) { struct aws_byte_cursor to_append = aws_byte_cursor_from_array(input_buffers[1].pvBuffer, decrypted_length); int append_failed = aws_byte_buf_append(&sc_handler->buffered_read_out_data_buf, &to_append); - AWS_ASSERT(!append_failed); - (void)append_failed; + AWS_FATAL_ASSERT(!append_failed); /* if we have extra we have to move the pointer and do another Decrypt operation. */ if (input_buffers[3].BufferType == SECBUFFER_EXTRA) { @@ -1160,8 +1163,12 @@ static int s_do_application_data_decrypt(struct aws_channel_handler *handler) { static int s_process_pending_output_messages(struct aws_channel_handler *handler) { struct secure_channel_handler *sc_handler = handler->impl; + if (sc_handler->read_state == AWS_TLS_HANDLER_READ_SHUT_DOWN_COMPLETE) { + return AWS_OP_SUCCESS; + } size_t downstream_window = SIZE_MAX; + int error_code = 0; if (sc_handler->slot->adj_right) { downstream_window = aws_channel_slot_downstream_read_window(sc_handler->slot); @@ -1200,7 +1207,8 @@ static int s_process_pending_output_messages(struct aws_channel_handler *handler } if (aws_channel_slot_send_message(sc_handler->slot, read_out_msg, AWS_CHANNEL_DIR_READ)) { aws_mem_release(read_out_msg->allocator, read_out_msg); - return AWS_OP_ERR; + error_code = aws_last_error(); + goto done; } if (sc_handler->slot->adj_right) { @@ -1216,17 +1224,36 @@ static int s_process_pending_output_messages(struct aws_channel_handler *handler } } - return AWS_OP_SUCCESS; + if (sc_handler->buffered_read_out_data_buf.len > 0) { + /* Still have more data to be delivered */ + return AWS_OP_SUCCESS; + } + +done: + if (sc_handler->read_state == AWS_TLS_HANDLER_READ_SHUTTING_DOWN) { + sc_handler->read_state = AWS_TLS_HANDLER_READ_SHUT_DOWN_COMPLETE; + /* Continue the shutdown process delayed before. */ + + /* Propagate the original error code if it is set. 
*/ + int shutdown_error_code = sc_handler->shutdown_error_code ? sc_handler->shutdown_error_code : error_code; + + aws_channel_slot_on_handler_shutdown_complete( + sc_handler->slot, AWS_CHANNEL_DIR_READ, shutdown_error_code, false); + } + + /* If there was an error, re-raise it, in case some other function call modified aws_last_error() */ + return error_code ? aws_raise_error(error_code) : AWS_OP_SUCCESS; } static void s_process_pending_output_task(struct aws_channel_task *task, void *arg, enum aws_task_status status) { (void)task; struct aws_channel_handler *handler = arg; + struct secure_channel_handler *sc_handler = handler->impl; + sc_handler->read_task_pending = false; aws_channel_task_init(task, NULL, NULL, "secure_channel_handler_process_pending_output"); if (status == AWS_TASK_STATUS_RUN_READY) { if (s_process_pending_output_messages(handler)) { - struct secure_channel_handler *sc_handler = arg; aws_channel_shutdown(sc_handler->slot->channel, aws_last_error()); } } @@ -1237,70 +1264,71 @@ static int s_process_read_message( struct aws_channel_slot *slot, struct aws_io_message *message) { + AWS_ASSERT(message); struct secure_channel_handler *sc_handler = handler->impl; - if (message) { - /* note, most of these functions log internally, so the log messages in this function are sparse. */ - AWS_LOGF_TRACE( - AWS_LS_IO_TLS, - "id=%p: processing incoming message of size %zu", - (void *)handler, - message->message_data.len); - - struct aws_byte_cursor message_cursor = aws_byte_cursor_from_buf(&message->message_data); + if (sc_handler->read_state == AWS_TLS_HANDLER_READ_SHUT_DOWN_COMPLETE) { + aws_mem_release(message->allocator, message); + return AWS_OP_SUCCESS; + } - /* The SSPI interface forces us to manage incomplete records manually. So when we had extra after - the previous read, it needs to be shifted to the beginning of the current read, then the current - read data is appended to it. 
If we had an incomplete record, we don't need to shift anything but - we do need to append the current read data to the end of the incomplete record from the previous read. - Keep going until we've processed everything in the message we were just passed. - */ - int err = AWS_OP_SUCCESS; - while (!err && message_cursor.len) { + /* note, most of these functions log internally, so the log messages in this function are sparse. */ + AWS_LOGF_TRACE( + AWS_LS_IO_TLS, "id=%p: processing incoming message of size %zu", (void *)handler, message->message_data.len); - size_t available_buffer_space = - sc_handler->buffered_read_in_data_buf.capacity - sc_handler->buffered_read_in_data_buf.len; - size_t available_message_len = message_cursor.len; - size_t amount_to_move_to_buffer = - available_buffer_space > available_message_len ? available_message_len : available_buffer_space; + struct aws_byte_cursor message_cursor = aws_byte_cursor_from_buf(&message->message_data); - memcpy( - sc_handler->buffered_read_in_data_buf.buffer + sc_handler->buffered_read_in_data_buf.len, - message_cursor.ptr, - amount_to_move_to_buffer); - sc_handler->buffered_read_in_data_buf.len += amount_to_move_to_buffer; + /* The SSPI interface forces us to manage incomplete records manually. So when we had extra after + the previous read, it needs to be shifted to the beginning of the current read, then the current + read data is appended to it. If we had an incomplete record, we don't need to shift anything but + we do need to append the current read data to the end of the incomplete record from the previous read. + Keep going until we've processed everything in the message we were just passed. 
+ */ + int err = AWS_OP_SUCCESS; + while (!err && message_cursor.len) { + /* copy as much data as possible into buffered_read_in_data_buf */ + aws_byte_buf_write_to_capacity(&sc_handler->buffered_read_in_data_buf, &message_cursor); + + /* decrypt */ + bool record_is_incomplete = false; + err = sc_handler->s_connection_state_fn(handler); + if (err) { + /* AWS_IO_READ_WOULD_BLOCK isn't fatal, it just means the record is incomplete */ + if (aws_last_error() == AWS_IO_READ_WOULD_BLOCK) { + err = AWS_OP_SUCCESS; + record_is_incomplete = true; + } else { + break; + } + } - err = sc_handler->s_connection_state_fn(handler); + /* if any data was decrypted, try to send it downstream */ + if (sc_handler->buffered_read_out_data_buf.len) { + err = s_process_pending_output_messages(handler); + if (err) { + break; + } + } - if (err && aws_last_error() == AWS_IO_READ_WOULD_BLOCK) { - if (sc_handler->buffered_read_in_data_buf.len == sc_handler->buffered_read_in_data_buf.capacity) { - /* throw this one as a protocol error. */ - aws_raise_error(AWS_IO_TLS_ERROR_WRITE_FAILURE); - } else { - if (sc_handler->buffered_read_out_data_buf.len) { - err = s_process_pending_output_messages(handler); - if (err) { - break; - } - } - /* prevent a deadlock due to downstream handlers wanting more data, but we have an incomplete - record, and the amount they're requesting is less than the size of a tls record. */ - size_t window_size = slot->window_size; - if (!window_size && - aws_channel_slot_increment_read_window(slot, sc_handler->estimated_incomplete_size)) { - err = AWS_OP_ERR; - } else { - sc_handler->estimated_incomplete_size = 0; - err = AWS_OP_SUCCESS; - } - } - aws_byte_cursor_advance(&message_cursor, amount_to_move_to_buffer); - continue; - } else if (err) { + if (record_is_incomplete) { + /* if our buffer is full, but the record is still incomplete ... throw this one as a protocol error. 
*/ + if (sc_handler->buffered_read_in_data_buf.len == sc_handler->buffered_read_in_data_buf.capacity) { + err = aws_raise_error(AWS_IO_TLS_ERROR_WRITE_FAILURE); break; } - /* handle any left over extra data from the decrypt operation here. */ + /* prevent a deadlock due to downstream handlers wanting more data, but we have an incomplete + record, and the amount they're requesting is less than the size of a tls record. */ + size_t downstream_window = + sc_handler->slot->adj_right ? aws_channel_slot_downstream_read_window(sc_handler->slot) : SIZE_MAX; + if (downstream_window > 0 && slot->window_size == 0) { + err = aws_channel_slot_increment_read_window(slot, sc_handler->estimated_incomplete_size); + if (err) { + break; + } + } + } else { + /* we had a complete record. handle any left over extra data from the decrypt operation here. */ if (sc_handler->read_extra) { size_t move_pos = sc_handler->buffered_read_in_data_buf.len - sc_handler->read_extra; memmove( @@ -1312,32 +1340,15 @@ static int s_process_read_message( } else { sc_handler->buffered_read_in_data_buf.len = 0; } - - if (sc_handler->buffered_read_out_data_buf.len) { - err = s_process_pending_output_messages(handler); - if (err) { - break; - } - } - aws_byte_cursor_advance(&message_cursor, amount_to_move_to_buffer); - } - - if (!err) { - aws_mem_release(message->allocator, message); - return AWS_OP_SUCCESS; } + } + if (err) { aws_channel_shutdown(slot->channel, aws_last_error()); return AWS_OP_ERR; } - if (sc_handler->buffered_read_out_data_buf.len) { - if (s_process_pending_output_messages(handler)) { - return AWS_OP_ERR; - } - aws_mem_release(message->allocator, message); - } - + aws_mem_release(message->allocator, message); return AWS_OP_SUCCESS; } @@ -1455,6 +1466,9 @@ static int s_process_write_message( static int s_increment_read_window(struct aws_channel_handler *handler, struct aws_channel_slot *slot, size_t size) { (void)size; struct secure_channel_handler *sc_handler = handler->impl; + if 
(sc_handler->read_state == AWS_TLS_HANDLER_READ_SHUT_DOWN_COMPLETE) { + return AWS_OP_SUCCESS; + } AWS_LOGF_TRACE(AWS_LS_IO_TLS, "id=%p: Increment read window message received %zu", (void *)handler, size); /* You can't query a context if negotiation isn't completed, since ciphers haven't been negotiated @@ -1492,13 +1506,17 @@ static int s_increment_read_window(struct aws_channel_handler *handler, struct a aws_channel_slot_increment_read_window(slot, window_update_size); } - if (sc_handler->negotiation_finished && !sc_handler->sequential_task_storage.task_fn) { + if (sc_handler->negotiation_finished && !sc_handler->read_task_pending) { + /* TLS requires full records before it can decrypt anything. As a result we need to check everything we've + * buffered instead of just waiting on a read from the socket, or we'll hit a deadlock. + */ + sc_handler->read_task_pending = true; aws_channel_task_init( - &sc_handler->sequential_task_storage, + &sc_handler->read_task, s_process_pending_output_task, handler, "secure_channel_handler_process_pending_output_on_window_increment"); - aws_channel_schedule_task_now(slot->channel, &sc_handler->sequential_task_storage); + aws_channel_schedule_task_now(slot->channel, &sc_handler->read_task); } return AWS_OP_SUCCESS; } @@ -1511,6 +1529,43 @@ static size_t s_initial_window_size(struct aws_channel_handler *handler) { return EST_HANDSHAKE_SIZE; } +static void s_initialize_read_delay_shutdown( + struct aws_channel_handler *handler, + struct aws_channel_slot *slot, + int error_code) { + struct secure_channel_handler *sc_handler = handler->impl; + /** + * In case of if we have any queued data in the handler after negotiation and we start to shutdown, + * make sure we pass those data down the pipeline before we complete the shutdown. + */ + AWS_LOGF_DEBUG( + AWS_LS_IO_TLS, + "id=%p: TLS handler still have pending data to be delivered during shutdown. 
Wait until downstream " + "reads the data.", + (void *)handler); + if (aws_channel_slot_downstream_read_window(slot) == 0) { + AWS_LOGF_WARN( + AWS_LS_IO_TLS, + "id=%p: TLS shutdown delayed. Pending data cannot be processed until the flow-control window opens. " + " Your application may hang if the read window never opens", + (void *)handler); + } + sc_handler->read_state = AWS_TLS_HANDLER_READ_SHUTTING_DOWN; + sc_handler->shutdown_error_code = error_code; + if (!sc_handler->read_task_pending) { + /* Kick off read, in case data arrives with TLS negotiation. Shutdown starts right after negotiation. + * Nothing will kick off read in that case. */ + sc_handler->read_task_pending = true; + aws_channel_task_init( + &sc_handler->read_task, + s_process_pending_output_task, + handler, + "secure_channel_handler_read_on_delay_shutdown"); + + aws_channel_schedule_task_now(slot->channel, &sc_handler->read_task); + } +} + static int s_handler_shutdown( struct aws_channel_handler *handler, struct aws_channel_slot *slot, @@ -1531,7 +1586,17 @@ static int s_handler_shutdown( .pBuffers = &output_buffer, }; - if (dir == AWS_CHANNEL_DIR_WRITE) { + if (dir == AWS_CHANNEL_DIR_READ) { + AWS_LOGF_DEBUG( + AWS_LS_IO_TLS, "id=%p: shutting down read direction with error %d.", (void *)handler, error_code); + if (!abort_immediately && sc_handler->negotiation_finished && sc_handler->buffered_read_out_data_buf.len && + slot->adj_right) { + s_initialize_read_delay_shutdown(handler, slot, error_code); + return AWS_OP_SUCCESS; + } + sc_handler->read_state = AWS_TLS_HANDLER_READ_SHUT_DOWN_COMPLETE; + } else { + /* Shutdown in write direction */ if (!abort_immediately && error_code != AWS_IO_SOCKET_CLOSED) { AWS_LOGF_DEBUG(AWS_LS_IO_TLS, "id=%p: Shutting down the write direction", (void *)handler); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 13132d659..534b197d9 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -237,6 +237,8 @@ if(NOT BYO_CRYPTO) 
add_net_test_case(test_concurrent_cert_import) add_net_test_case(test_duplicate_cert_import) add_net_test_case(tls_channel_echo_and_backpressure_test) + add_net_test_case(tls_channel_shutdown_with_cache_test) + add_net_test_case(tls_channel_shutdown_with_cache_window_update_after_shutdown_test) add_net_test_case(tls_client_channel_negotiation_error_socket_closed) add_net_test_case(tls_client_channel_negotiation_success) add_net_test_case(tls_server_multiple_connections) diff --git a/tests/event_loop_test.c b/tests/event_loop_test.c index 191ea7fb1..659f313c6 100644 --- a/tests/event_loop_test.c +++ b/tests/event_loop_test.c @@ -286,7 +286,7 @@ static int s_test_event_loop_completion_events(struct aws_allocator *allocator, ASSERT_SUCCESS(s_async_pipe_init(&read_handle, &write_handle)); /* Connect to event-loop */ - ASSERT_SUCCESS(aws_event_loop_connect_handle_to_io_completion_port(event_loop, &write_handle)); + ASSERT_SUCCESS(aws_event_loop_connect_handle_to_completion_port(event_loop, &write_handle)); /* Set up an async (overlapped) write that will result in s_on_overlapped_operation_complete() getting run * and filling out `completion_data` */ @@ -1057,7 +1057,7 @@ static int s_event_loop_test_multiple_stops(struct aws_allocator *allocator, voi ASSERT_NOT_NULL(event_loop, "Event loop creation failed with error: %s", aws_error_debug_str(aws_last_error())); ASSERT_SUCCESS(aws_event_loop_run(event_loop)); - for (int i = 0; i < 8; ++i) { + for (int j = 0; j < 8; ++j) { ASSERT_SUCCESS(aws_event_loop_stop(event_loop)); } aws_event_loop_destroy(event_loop); diff --git a/tests/socket_test.c b/tests/socket_test.c index 52aac70bf..60ee5e4ab 100644 --- a/tests/socket_test.c +++ b/tests/socket_test.c @@ -477,7 +477,7 @@ static int s_test_socket_with_bind_to_interface(struct aws_allocator *allocator, #else strncpy(options.network_interface_name, "lo", AWS_NETWORK_INTERFACE_NAME_MAX); #endif - struct aws_socket_endpoint endpoint = {.address = "127.0.0.1", .port = 8127}; + 
struct aws_socket_endpoint endpoint = {.address = "127.0.0.1", .port = 8128}; if (s_test_socket(allocator, &options, &endpoint)) { #if !defined(AWS_OS_APPLE) && !defined(AWS_OS_LINUX) if (aws_last_error() == AWS_ERROR_PLATFORM_NOT_SUPPORTED) { @@ -490,7 +490,7 @@ static int s_test_socket_with_bind_to_interface(struct aws_allocator *allocator, options.domain = AWS_SOCKET_IPV4; ASSERT_SUCCESS(s_test_socket(allocator, &options, &endpoint)); - struct aws_socket_endpoint endpoint_ipv6 = {.address = "::1", .port = 1024}; + struct aws_socket_endpoint endpoint_ipv6 = {.address = "::1", .port = 8129}; options.type = AWS_SOCKET_STREAM; options.domain = AWS_SOCKET_IPV6; if (s_test_socket(allocator, &options, &endpoint_ipv6)) { diff --git a/tests/tls_handler_test.c b/tests/tls_handler_test.c index e6db170a7..0b0f5c88c 100644 --- a/tests/tls_handler_test.c +++ b/tests/tls_handler_test.c @@ -167,6 +167,7 @@ static int s_tls_test_arg_init( } static int s_tls_common_tester_init(struct aws_allocator *allocator, struct tls_common_tester *tester) { + aws_io_library_init(allocator); AWS_ZERO_STRUCT(*tester); struct aws_mutex mutex = AWS_MUTEX_INIT; @@ -329,6 +330,9 @@ static void s_tls_handler_test_client_shutdown_callback( aws_mutex_lock(setup_test_args->mutex); setup_test_args->shutdown_finished = true; + if (error_code) { + setup_test_args->last_error_code = error_code; + } aws_mutex_unlock(setup_test_args->mutex); aws_condition_variable_notify_one(setup_test_args->condition_variable); } @@ -347,8 +351,11 @@ static void s_tls_handler_test_server_shutdown_callback( aws_mutex_lock(setup_test_args->mutex); setup_test_args->shutdown_finished = true; - aws_mutex_unlock(setup_test_args->mutex); + if (error_code) { + setup_test_args->last_error_code = error_code; + } aws_condition_variable_notify_one(setup_test_args->condition_variable); + aws_mutex_unlock(setup_test_args->mutex); } static void s_tls_handler_test_server_listener_destroy_callback( @@ -359,9 +366,8 @@ static void 
s_tls_handler_test_server_listener_destroy_callback( struct tls_test_args *setup_test_args = (struct tls_test_args *)user_data; aws_mutex_lock(setup_test_args->mutex); setup_test_args->listener_destroyed = true; + aws_condition_variable_notify_all(setup_test_args->condition_variable); aws_mutex_unlock(setup_test_args->mutex); - - aws_condition_variable_notify_one(setup_test_args->condition_variable); } static void s_tls_on_negotiated( @@ -497,74 +503,120 @@ static struct aws_byte_buf s_tls_test_handle_write( return (struct aws_byte_buf){0}; } -static int s_tls_channel_echo_and_backpressure_test_fn(struct aws_allocator *allocator, void *ctx) { - (void)ctx; - aws_io_library_init(allocator); - ASSERT_SUCCESS(s_tls_common_tester_init(allocator, &c_tester)); +static uint8_t s_server_received_message[128] = {0}; +static uint8_t s_client_received_message[128] = {0}; - struct aws_byte_buf read_tag = aws_byte_buf_from_c_str("I'm a little teapot."); - struct aws_byte_buf write_tag = aws_byte_buf_from_c_str("I'm a big teapot"); +/* common structure for test with self-initialized server and client */ +struct tls_channel_server_client_tester { + struct tls_test_rw_args client_rw_args; + struct tls_test_rw_args server_rw_args; + struct tls_test_args client_args; + struct tls_test_args server_args; + struct aws_client_bootstrap *client_bootstrap; + struct tls_local_server_tester local_server_tester; - uint8_t incoming_received_message[128] = {0}; - uint8_t outgoing_received_message[128] = {0}; + struct aws_mutex server_mutex; + struct aws_condition_variable server_condition_variable; + + struct aws_atomic_var server_shutdown_invoked; + /* Make sure server and client don't use the same thread */ + struct aws_event_loop_group *client_el_group; + + bool window_update_after_shutdown; +}; + +static struct tls_channel_server_client_tester s_server_client_tester; + +static int s_tls_channel_server_client_tester_init(struct aws_allocator *allocator) { + 
ASSERT_SUCCESS(s_tls_common_tester_init(allocator, &c_tester)); + AWS_ZERO_STRUCT(s_server_client_tester); + ASSERT_SUCCESS(aws_mutex_init(&s_server_client_tester.server_mutex)); + ASSERT_SUCCESS(aws_condition_variable_init(&s_server_client_tester.server_condition_variable)); + s_server_client_tester.client_el_group = aws_event_loop_group_new_default(allocator, 0, NULL); - struct tls_test_rw_args incoming_rw_args; ASSERT_SUCCESS(s_tls_rw_args_init( - &incoming_rw_args, + &s_server_client_tester.server_rw_args, &c_tester, - aws_byte_buf_from_empty_array(incoming_received_message, sizeof(incoming_received_message)))); - - struct tls_test_rw_args outgoing_rw_args; + aws_byte_buf_from_empty_array(s_server_received_message, sizeof(s_server_received_message)))); + s_server_client_tester.server_rw_args.mutex = &s_server_client_tester.server_mutex; + s_server_client_tester.server_rw_args.condition_variable = &s_server_client_tester.server_condition_variable; ASSERT_SUCCESS(s_tls_rw_args_init( - &outgoing_rw_args, + &s_server_client_tester.client_rw_args, &c_tester, - aws_byte_buf_from_empty_array(outgoing_received_message, sizeof(outgoing_received_message)))); + aws_byte_buf_from_empty_array(s_client_received_message, sizeof(s_client_received_message)))); + ASSERT_SUCCESS(s_tls_test_arg_init(allocator, &s_server_client_tester.client_args, false, &c_tester)); + ASSERT_SUCCESS(s_tls_test_arg_init(allocator, &s_server_client_tester.server_args, true, &c_tester)); + s_server_client_tester.server_args.mutex = &s_server_client_tester.server_mutex; + s_server_client_tester.server_args.condition_variable = &s_server_client_tester.server_condition_variable; - struct tls_test_args outgoing_args; - ASSERT_SUCCESS(s_tls_test_arg_init(allocator, &outgoing_args, false, &c_tester)); + ASSERT_SUCCESS(s_tls_local_server_tester_init( + allocator, + &s_server_client_tester.local_server_tester, + &s_server_client_tester.server_args, + &c_tester, + true, + "server.crt", + "server.key")); + 
struct aws_client_bootstrap_options bootstrap_options = { + .event_loop_group = s_server_client_tester.client_el_group, + .host_resolver = c_tester.resolver, + }; + s_server_client_tester.client_bootstrap = aws_client_bootstrap_new(allocator, &bootstrap_options); - struct tls_test_args incoming_args; - ASSERT_SUCCESS(s_tls_test_arg_init(allocator, &incoming_args, true, &c_tester)); + aws_atomic_store_int(&s_server_client_tester.server_shutdown_invoked, 0); + return AWS_OP_SUCCESS; +} - struct tls_local_server_tester local_server_tester; - ASSERT_SUCCESS(s_tls_local_server_tester_init( - allocator, &local_server_tester, &incoming_args, &c_tester, true, "server.crt", "server.key")); - /* make the windows small to make sure back pressure is honored. */ - struct aws_channel_handler *outgoing_rw_handler = rw_handler_new( - allocator, s_tls_test_handle_read, s_tls_test_handle_write, true, write_tag.len / 2, &outgoing_rw_args); - ASSERT_NOT_NULL(outgoing_rw_handler); +static int s_tls_channel_server_client_tester_cleanup(void) { + /* Make sure client and server all shutdown */ + ASSERT_SUCCESS(aws_mutex_lock(&c_tester.mutex)); + ASSERT_SUCCESS(aws_condition_variable_wait_pred( + &c_tester.condition_variable, + &c_tester.mutex, + s_tls_channel_shutdown_predicate, + &s_server_client_tester.client_args)); + aws_mutex_unlock(&c_tester.mutex); - struct aws_channel_handler *incoming_rw_handler = rw_handler_new( - allocator, s_tls_test_handle_read, s_tls_test_handle_write, true, read_tag.len / 2, &incoming_rw_args); - ASSERT_NOT_NULL(incoming_rw_handler); + aws_server_bootstrap_destroy_socket_listener( + s_server_client_tester.local_server_tester.server_bootstrap, + s_server_client_tester.local_server_tester.listener); + ASSERT_SUCCESS(s_tls_local_server_tester_clean_up(&s_server_client_tester.local_server_tester)); + ASSERT_SUCCESS(aws_mutex_lock(&s_server_client_tester.server_mutex)); + ASSERT_SUCCESS(aws_condition_variable_wait_pred( + 
&s_server_client_tester.server_condition_variable, + &s_server_client_tester.server_mutex, + s_tls_listener_destroy_predicate, + &s_server_client_tester.server_args)); + ASSERT_SUCCESS(aws_mutex_unlock(&s_server_client_tester.server_mutex)); - incoming_args.rw_handler = incoming_rw_handler; - outgoing_args.rw_handler = outgoing_rw_handler; + /* Clean up */ + aws_mutex_clean_up(&s_server_client_tester.server_mutex); + aws_condition_variable_clean_up(&s_server_client_tester.server_condition_variable); + aws_client_bootstrap_release(s_server_client_tester.client_bootstrap); + aws_event_loop_group_release(s_server_client_tester.client_el_group); + ASSERT_SUCCESS(s_tls_common_tester_clean_up(&c_tester)); + return AWS_OP_SUCCESS; +} - g_aws_channel_max_fragment_size = 4096; +static int s_set_socket_channel(struct tls_channel_server_client_tester *server_client_tester) { struct tls_opt_tester client_tls_opt_tester; struct aws_byte_cursor server_name = aws_byte_cursor_from_c_str("localhost"); - ASSERT_SUCCESS(s_tls_client_opt_tester_init(allocator, &client_tls_opt_tester, server_name)); + ASSERT_SUCCESS( + s_tls_client_opt_tester_init(server_client_tester->client_args.allocator, &client_tls_opt_tester, server_name)); aws_tls_connection_options_set_callbacks( - &client_tls_opt_tester.opt, s_tls_on_negotiated, NULL, NULL, &outgoing_args); - - struct aws_client_bootstrap_options bootstrap_options = { - .event_loop_group = c_tester.el_group, - .host_resolver = c_tester.resolver, - }; - struct aws_client_bootstrap *client_bootstrap = aws_client_bootstrap_new(allocator, &bootstrap_options); + &client_tls_opt_tester.opt, s_tls_on_negotiated, NULL, NULL, &server_client_tester->client_args); struct aws_socket_channel_bootstrap_options channel_options; AWS_ZERO_STRUCT(channel_options); - channel_options.bootstrap = client_bootstrap; - channel_options.host_name = local_server_tester.endpoint.address; + channel_options.bootstrap = server_client_tester->client_bootstrap; + 
channel_options.host_name = server_client_tester->local_server_tester.endpoint.address; channel_options.port = 0; - channel_options.socket_options = &local_server_tester.socket_options; + channel_options.socket_options = &server_client_tester->local_server_tester.socket_options; channel_options.tls_options = &client_tls_opt_tester.opt; channel_options.setup_callback = s_tls_handler_test_client_setup_callback; channel_options.shutdown_callback = s_tls_handler_test_client_shutdown_callback; - channel_options.user_data = &outgoing_args; + channel_options.user_data = &server_client_tester->client_args; channel_options.enable_read_back_pressure = true; ASSERT_SUCCESS(aws_client_bootstrap_new_socket_channel(&channel_options)); @@ -573,11 +625,14 @@ static int s_tls_channel_echo_and_backpressure_test_fn(struct aws_allocator *all * done messed up. */ aws_tls_connection_options_clean_up(&client_tls_opt_tester.opt); /* wait for both ends to setup */ - ASSERT_SUCCESS(aws_mutex_lock(&c_tester.mutex)); + ASSERT_SUCCESS(aws_mutex_lock(&s_server_client_tester.server_mutex)); ASSERT_SUCCESS(aws_condition_variable_wait_pred( - &c_tester.condition_variable, &c_tester.mutex, s_tls_channel_setup_predicate, &incoming_args)); - ASSERT_SUCCESS(aws_mutex_unlock(&c_tester.mutex)); - ASSERT_FALSE(incoming_args.error_invoked); + &s_server_client_tester.server_condition_variable, + &s_server_client_tester.server_mutex, + s_tls_channel_setup_predicate, + &server_client_tester->server_args)); + ASSERT_SUCCESS(aws_mutex_unlock(&s_server_client_tester.server_mutex)); + ASSERT_FALSE(server_client_tester->server_args.error_invoked); /* currently it seems ALPN doesn't work in server mode. Just leaving this check out for now. 
*/ # ifndef __APPLE__ @@ -588,16 +643,19 @@ static int s_tls_channel_echo_and_backpressure_test_fn(struct aws_allocator *all ASSERT_BIN_ARRAYS_EQUALS( expected_protocol.buffer, expected_protocol.len, - incoming_args.negotiated_protocol.buffer, - incoming_args.negotiated_protocol.len); + server_client_tester->server_args.negotiated_protocol.buffer, + server_client_tester->server_args.negotiated_protocol.len); } # endif ASSERT_SUCCESS(aws_mutex_lock(&c_tester.mutex)); ASSERT_SUCCESS(aws_condition_variable_wait_pred( - &c_tester.condition_variable, &c_tester.mutex, s_tls_channel_setup_predicate, &outgoing_args)); + &c_tester.condition_variable, + &c_tester.mutex, + s_tls_channel_setup_predicate, + &server_client_tester->client_args)); ASSERT_SUCCESS(aws_mutex_unlock(&c_tester.mutex)); - ASSERT_FALSE(outgoing_args.error_invoked); + ASSERT_FALSE(server_client_tester->client_args.error_invoked); /* currently it seems ALPN doesn't work in server mode. Just leaving this check out for now. */ # ifndef __MACH__ @@ -605,76 +663,254 @@ static int s_tls_channel_echo_and_backpressure_test_fn(struct aws_allocator *all ASSERT_BIN_ARRAYS_EQUALS( expected_protocol.buffer, expected_protocol.len, - outgoing_args.negotiated_protocol.buffer, - outgoing_args.negotiated_protocol.len); + server_client_tester->client_args.negotiated_protocol.buffer, + server_client_tester->client_args.negotiated_protocol.len); } # endif - ASSERT_FALSE(outgoing_args.error_invoked); + ASSERT_SUCCESS(s_tls_opt_tester_clean_up(&client_tls_opt_tester)); + return AWS_OP_SUCCESS; +} + +static int s_tls_channel_echo_and_backpressure_test_fn(struct aws_allocator *allocator, void *ctx) { + (void)ctx; + ASSERT_SUCCESS(s_tls_channel_server_client_tester_init(allocator)); + struct tls_test_rw_args *client_rw_args = &s_server_client_tester.client_rw_args; + struct tls_test_rw_args *server_rw_args = &s_server_client_tester.server_rw_args; + struct tls_test_args *client_args = &s_server_client_tester.client_args; + struct 
tls_test_args *server_args = &s_server_client_tester.server_args; + + struct aws_byte_buf read_tag = aws_byte_buf_from_c_str("I'm a little teapot."); + struct aws_byte_buf write_tag = aws_byte_buf_from_c_str("I'm a big teapot"); + + /* make the windows small to make sure back pressure is honored. */ + struct aws_channel_handler *client_rw_handler = rw_handler_new( + allocator, s_tls_test_handle_read, s_tls_test_handle_write, true, write_tag.len / 2, client_rw_args); + ASSERT_NOT_NULL(client_rw_handler); + struct aws_channel_handler *server_rw_handler = rw_handler_new( + allocator, s_tls_test_handle_read, s_tls_test_handle_write, true, read_tag.len / 2, server_rw_args); + ASSERT_NOT_NULL(server_rw_handler); + server_args->rw_handler = server_rw_handler; + client_args->rw_handler = client_rw_handler; + + g_aws_channel_max_fragment_size = 4096; + ASSERT_SUCCESS(s_set_socket_channel(&s_server_client_tester)); /* Do the IO operations */ - rw_handler_write(outgoing_args.rw_handler, outgoing_args.rw_slot, &write_tag); - rw_handler_write(incoming_args.rw_handler, incoming_args.rw_slot, &read_tag); + rw_handler_write(client_args->rw_handler, client_args->rw_slot, &write_tag); + rw_handler_write(server_args->rw_handler, server_args->rw_slot, &read_tag); ASSERT_SUCCESS(aws_mutex_lock(&c_tester.mutex)); ASSERT_SUCCESS(aws_condition_variable_wait_pred( - &c_tester.condition_variable, &c_tester.mutex, s_tls_test_read_predicate, &incoming_rw_args)); - ASSERT_SUCCESS(aws_condition_variable_wait_pred( - &c_tester.condition_variable, &c_tester.mutex, s_tls_test_read_predicate, &outgoing_rw_args)); + &c_tester.condition_variable, &c_tester.mutex, s_tls_test_read_predicate, client_rw_args)); ASSERT_SUCCESS(aws_mutex_unlock(&c_tester.mutex)); - incoming_rw_args.invocation_happened = false; - outgoing_rw_args.invocation_happened = false; + ASSERT_SUCCESS(aws_mutex_lock(&s_server_client_tester.server_mutex)); + ASSERT_SUCCESS(aws_condition_variable_wait_pred( + 
&s_server_client_tester.server_condition_variable, + &s_server_client_tester.server_mutex, + s_tls_test_read_predicate, + server_rw_args)); + ASSERT_SUCCESS(aws_mutex_unlock(&s_server_client_tester.server_mutex)); + + server_rw_args->invocation_happened = false; + client_rw_args->invocation_happened = false; - ASSERT_INT_EQUALS(1, outgoing_rw_args.read_invocations); - ASSERT_INT_EQUALS(1, incoming_rw_args.read_invocations); + ASSERT_INT_EQUALS(1, client_rw_args->read_invocations); + ASSERT_INT_EQUALS(1, server_rw_args->read_invocations); /* Go ahead and verify back-pressure works*/ - rw_handler_trigger_increment_read_window(incoming_args.rw_handler, incoming_args.rw_slot, 100); - rw_handler_trigger_increment_read_window(outgoing_args.rw_handler, outgoing_args.rw_slot, 100); + rw_handler_trigger_increment_read_window(server_args->rw_handler, server_args->rw_slot, 100); + rw_handler_trigger_increment_read_window(client_args->rw_handler, client_args->rw_slot, 100); ASSERT_SUCCESS(aws_mutex_lock(&c_tester.mutex)); ASSERT_SUCCESS(aws_condition_variable_wait_pred( - &c_tester.condition_variable, &c_tester.mutex, s_tls_test_read_predicate, &incoming_rw_args)); - ASSERT_SUCCESS(aws_condition_variable_wait_pred( - &c_tester.condition_variable, &c_tester.mutex, s_tls_test_read_predicate, &outgoing_rw_args)); + &c_tester.condition_variable, &c_tester.mutex, s_tls_test_read_predicate, client_rw_args)); ASSERT_SUCCESS(aws_mutex_unlock(&c_tester.mutex)); - ASSERT_INT_EQUALS(2, outgoing_rw_args.read_invocations); - ASSERT_INT_EQUALS(2, incoming_rw_args.read_invocations); + ASSERT_SUCCESS(aws_mutex_lock(&s_server_client_tester.server_mutex)); + ASSERT_SUCCESS(aws_condition_variable_wait_pred( + &s_server_client_tester.server_condition_variable, + &s_server_client_tester.server_mutex, + s_tls_test_read_predicate, + server_rw_args)); + ASSERT_SUCCESS(aws_mutex_unlock(&s_server_client_tester.server_mutex)); + + ASSERT_INT_EQUALS(2, client_rw_args->read_invocations); + 
ASSERT_INT_EQUALS(2, server_rw_args->read_invocations); ASSERT_BIN_ARRAYS_EQUALS( - write_tag.buffer, - write_tag.len, - incoming_rw_args.received_message.buffer, - incoming_rw_args.received_message.len); + write_tag.buffer, write_tag.len, server_rw_args->received_message.buffer, server_rw_args->received_message.len); ASSERT_BIN_ARRAYS_EQUALS( - read_tag.buffer, read_tag.len, outgoing_rw_args.received_message.buffer, outgoing_rw_args.received_message.len); + read_tag.buffer, read_tag.len, client_rw_args->received_message.buffer, client_rw_args->received_message.len); - aws_channel_shutdown(incoming_args.channel, AWS_OP_SUCCESS); - ASSERT_SUCCESS(aws_mutex_lock(&c_tester.mutex)); + aws_channel_shutdown(server_args->channel, AWS_OP_SUCCESS); + ASSERT_SUCCESS(aws_mutex_lock(&s_server_client_tester.server_mutex)); ASSERT_SUCCESS(aws_condition_variable_wait_pred( - &c_tester.condition_variable, &c_tester.mutex, s_tls_channel_shutdown_predicate, &incoming_args)); - ASSERT_SUCCESS(aws_mutex_unlock(&c_tester.mutex)); + &s_server_client_tester.server_condition_variable, + &s_server_client_tester.server_mutex, + s_tls_channel_shutdown_predicate, + &s_server_client_tester.server_args)); + ASSERT_SUCCESS(aws_mutex_unlock(&s_server_client_tester.server_mutex)); /*no shutdown on the client necessary here (it should have been triggered by shutting down the other side). just * wait for the event to fire. */ + ASSERT_SUCCESS(s_tls_channel_server_client_tester_cleanup()); + + return AWS_OP_SUCCESS; +} + +AWS_TEST_CASE(tls_channel_echo_and_backpressure_test, s_tls_channel_echo_and_backpressure_test_fn) + +static struct aws_byte_buf s_on_client_recive_shutdown_with_cache_data( + struct aws_channel_handler *handler, + struct aws_channel_slot *slot, + struct aws_byte_buf *data_read, + void *user_data) { + + /** + * Client received the data from server, and it happens from the channel thread. + * Because of the limited window size, we also have more data cached in the TLS handler. 
+ * + * Now: + * - Shutdown the server channel, and wait for it to finish, which will close the socket, and the socket will + * schedule the channel shutdown process when this function returns. + * - Update the window from this thread, it should schedule another task from channel thread to do so. + */ + (void)slot; + (void)user_data; + struct tls_test_rw_args *client_rw_args = &s_server_client_tester.client_rw_args; + + if (!rw_handler_shutdown_called(handler)) { + + size_t shutdown_invoked = aws_atomic_load_int(&s_server_client_tester.server_shutdown_invoked); + if (shutdown_invoked == 0) { + aws_atomic_store_int(&s_server_client_tester.server_shutdown_invoked, 1); + if (!s_server_client_tester.window_update_after_shutdown) { + rw_handler_trigger_increment_read_window( + s_server_client_tester.client_args.rw_handler, s_server_client_tester.client_args.rw_slot, 100); + } + aws_channel_shutdown(s_server_client_tester.server_args.channel, AWS_OP_SUCCESS); + + aws_mutex_lock(&s_server_client_tester.server_mutex); + aws_condition_variable_wait_pred( + &s_server_client_tester.server_condition_variable, + &s_server_client_tester.server_mutex, + s_tls_channel_shutdown_predicate, + &s_server_client_tester.server_args); + aws_mutex_unlock(&s_server_client_tester.server_mutex); + } + aws_mutex_lock(client_rw_args->mutex); + + aws_byte_buf_write_from_whole_buffer(&client_rw_args->received_message, *data_read); + client_rw_args->read_invocations += 1; + client_rw_args->invocation_happened = true; + + aws_mutex_unlock(client_rw_args->mutex); + aws_condition_variable_notify_one(client_rw_args->condition_variable); + } else { + AWS_FATAL_ASSERT(false && "The channel has already shutdown before process the message."); + } + return client_rw_args->received_message; +} + +/** + * Test that when the socket initiates the shutdown process because the socket was closed, we have a pending window update + * task to start the reading of the cached data in TLS handler. 
So, the channel will run the window update task, + * followed immediately by a shutdown task. + * + * Previously, the window update task would schedule a read task if it reopened a closed window, but since the + * shutdown task had already been scheduled, the read would happen after shutdown, resulting in loss of data. + */ +static int s_tls_channel_shutdown_with_cache_test_helper(struct aws_allocator *allocator, bool after_shutdown) { + ASSERT_SUCCESS(s_tls_channel_server_client_tester_init(allocator)); + s_server_client_tester.window_update_after_shutdown = after_shutdown; + + struct aws_byte_buf read_tag = aws_byte_buf_from_c_str("I'm a little teapot."); + struct aws_byte_buf write_tag = aws_byte_buf_from_c_str("I'm a big teapot"); + /* Initialize the handler for client with small window, and shutdown the server */ + struct aws_channel_handler *client_rw_handler = rw_handler_new( + allocator, + s_on_client_recive_shutdown_with_cache_data, + s_tls_test_handle_write, + true, + write_tag.len / 2, + &s_server_client_tester.client_rw_args); + ASSERT_NOT_NULL(client_rw_handler); + + struct aws_channel_handler *server_rw_handler = rw_handler_new( + allocator, + s_tls_test_handle_read, + s_tls_test_handle_write, + true, + SIZE_MAX, + &s_server_client_tester.server_rw_args); + ASSERT_NOT_NULL(server_rw_handler); + + s_server_client_tester.server_args.rw_handler = server_rw_handler; + s_server_client_tester.client_args.rw_handler = client_rw_handler; + + g_aws_channel_max_fragment_size = 4096; + ASSERT_SUCCESS(s_set_socket_channel(&s_server_client_tester)); + + /* Server sends data to client */ + rw_handler_write( + s_server_client_tester.server_args.rw_handler, s_server_client_tester.server_args.rw_slot, &read_tag); ASSERT_SUCCESS(aws_mutex_lock(&c_tester.mutex)); ASSERT_SUCCESS(aws_condition_variable_wait_pred( - &c_tester.condition_variable, &c_tester.mutex, s_tls_channel_shutdown_predicate, &outgoing_args)); - 
aws_server_bootstrap_destroy_socket_listener(local_server_tester.server_bootstrap, local_server_tester.listener); + &c_tester.condition_variable, + &c_tester.mutex, + s_tls_test_read_predicate, + &s_server_client_tester.client_rw_args)); + ASSERT_SUCCESS(aws_mutex_unlock(&c_tester.mutex)); + + if (s_server_client_tester.window_update_after_shutdown) { + rw_handler_trigger_increment_read_window( + s_server_client_tester.client_args.rw_handler, s_server_client_tester.client_args.rw_slot, 100); + } + + /* Make sure client also shutdown without error. */ + ASSERT_SUCCESS(aws_mutex_lock(&c_tester.mutex)); ASSERT_SUCCESS(aws_condition_variable_wait_pred( - &c_tester.condition_variable, &c_tester.mutex, s_tls_listener_destroy_predicate, &incoming_args)); + &c_tester.condition_variable, + &c_tester.mutex, + s_tls_channel_shutdown_predicate, + &s_server_client_tester.client_args)); aws_mutex_unlock(&c_tester.mutex); + + s_server_client_tester.client_rw_args.invocation_happened = false; + + ASSERT_INT_EQUALS(2, s_server_client_tester.client_rw_args.read_invocations); + + ASSERT_BIN_ARRAYS_EQUALS( + read_tag.buffer, + read_tag.len, + s_server_client_tester.client_rw_args.received_message.buffer, + s_server_client_tester.client_rw_args.received_message.len); + /* clean up */ - ASSERT_SUCCESS(s_tls_opt_tester_clean_up(&client_tls_opt_tester)); - aws_client_bootstrap_release(client_bootstrap); - ASSERT_SUCCESS(s_tls_local_server_tester_clean_up(&local_server_tester)); - ASSERT_SUCCESS(s_tls_common_tester_clean_up(&c_tester)); + /*no shutdown on the client necessary here (it should have been triggered by shutting down the other side). just + * wait for the event to fire. 
*/ + ASSERT_SUCCESS(s_tls_channel_server_client_tester_cleanup()); return AWS_OP_SUCCESS; } -AWS_TEST_CASE(tls_channel_echo_and_backpressure_test, s_tls_channel_echo_and_backpressure_test_fn) +static int s_tls_channel_shutdown_with_cache_test_fn(struct aws_allocator *allocator, void *ctx) { + (void)ctx; + return s_tls_channel_shutdown_with_cache_test_helper(allocator, false); +} + +AWS_TEST_CASE(tls_channel_shutdown_with_cache_test, s_tls_channel_shutdown_with_cache_test_fn) + +static int s_tls_channel_shutdown_with_cache_window_update_after_shutdown_test_fn( + struct aws_allocator *allocator, + void *ctx) { + (void)ctx; + return s_tls_channel_shutdown_with_cache_test_helper(allocator, true); +} +AWS_TEST_CASE( + tls_channel_shutdown_with_cache_window_update_after_shutdown_test, + s_tls_channel_shutdown_with_cache_window_update_after_shutdown_test_fn) struct default_host_callback_data { struct aws_host_address aaaa_address; @@ -768,8 +1004,6 @@ static int s_verify_negotiation_fails( uint32_t port, void (*context_options_override_fn)(struct aws_tls_ctx_options *)) { - aws_io_library_init(allocator); - ASSERT_SUCCESS(s_tls_common_tester_init(allocator, &c_tester)); struct aws_tls_ctx_options client_ctx_options; @@ -779,12 +1013,14 @@ static int s_verify_negotiation_fails( (*context_options_override_fn)(&client_ctx_options); } - ASSERT_SUCCESS(s_verify_negotiation_fails_helper(allocator, host_name, port, &client_ctx_options)); + int ret = s_verify_negotiation_fails_helper(allocator, host_name, port, &client_ctx_options); + if (ret == AWS_OP_SUCCESS) { + aws_tls_ctx_options_clean_up(&client_ctx_options); + ASSERT_SUCCESS(s_tls_common_tester_clean_up(&c_tester)); - aws_tls_ctx_options_clean_up(&client_ctx_options); - ASSERT_SUCCESS(s_tls_common_tester_clean_up(&c_tester)); - - return AWS_OP_SUCCESS; + return AWS_OP_SUCCESS; + } + return ret; } static int s_verify_negotiation_fails_with_ca_override( @@ -792,8 +1028,6 @@ static int 
s_verify_negotiation_fails_with_ca_override( const struct aws_string *host_name, const char *root_ca_path) { - aws_io_library_init(allocator); - ASSERT_SUCCESS(s_tls_common_tester_init(allocator, &c_tester)); struct aws_tls_ctx_options client_ctx_options; @@ -801,12 +1035,14 @@ static int s_verify_negotiation_fails_with_ca_override( ASSERT_SUCCESS(aws_tls_ctx_options_override_default_trust_store_from_path(&client_ctx_options, NULL, root_ca_path)); - ASSERT_SUCCESS(s_verify_negotiation_fails_helper(allocator, host_name, 443, &client_ctx_options)); - - ASSERT_SUCCESS(s_tls_common_tester_clean_up(&c_tester)); - aws_tls_ctx_options_clean_up(&client_ctx_options); + int ret = s_verify_negotiation_fails_helper(allocator, host_name, 443, &client_ctx_options); + if (ret == AWS_OP_SUCCESS) { + aws_tls_ctx_options_clean_up(&client_ctx_options); + ASSERT_SUCCESS(s_tls_common_tester_clean_up(&c_tester)); - return AWS_OP_SUCCESS; + return AWS_OP_SUCCESS; + } + return ret; } # if defined(USE_S2N) @@ -1031,8 +1267,6 @@ static int s_tls_client_channel_negotiation_error_socket_closed_fn(struct aws_al const char *host_name = "aws-crt-test-stuff.s3.amazonaws.com"; uint32_t port = 80; /* Note: intentionally wrong and not 443 */ - aws_io_library_init(allocator); - ASSERT_SUCCESS(s_tls_common_tester_init(allocator, &c_tester)); struct tls_opt_tester client_tls_opt_tester; @@ -1097,8 +1331,6 @@ static int s_verify_good_host( uint32_t port, void (*override_tls_options_fn)(struct aws_tls_ctx_options *)) { - aws_io_library_init(allocator); - ASSERT_SUCCESS(s_tls_common_tester_init(allocator, &c_tester)); struct tls_test_args outgoing_args = { @@ -1422,8 +1654,6 @@ static void s_reset_arg_state(struct tls_test_args *setup_test_args) { static int s_tls_server_multiple_connections_fn(struct aws_allocator *allocator, void *ctx) { (void)ctx; - aws_io_library_init(allocator); - ASSERT_SUCCESS(s_tls_common_tester_init(allocator, &c_tester)); struct tls_test_args outgoing_args; @@ -1571,8 +1801,6 @@ 
static void s_on_client_connected_do_hangup(struct aws_socket *socket, int error static int s_tls_server_hangup_during_negotiation_fn(struct aws_allocator *allocator, void *ctx) { (void)ctx; - aws_io_library_init(allocator); - ASSERT_SUCCESS(s_tls_common_tester_init(allocator, &c_tester)); struct tls_test_args outgoing_args; @@ -1837,8 +2065,6 @@ AWS_TEST_CASE(tls_channel_statistics_test, s_tls_channel_statistics_test) static int s_tls_certificate_chain_test(struct aws_allocator *allocator, void *ctx) { (void)ctx; - aws_io_library_init(allocator); - ASSERT_SUCCESS(s_tls_common_tester_init(allocator, &c_tester)); struct tls_test_args outgoing_args;