Skip to content

Commit

Permalink
Fix websocket shutdown behavior (#483)
Browse files Browse the repository at this point in the history
The bug was introduced in [PR #474](https://github.com/awslabs/aws-c-http/pull/474/files#diff-ee776c7576cfff50a64158d59a6173ab9a0aa373150574aa9987b4f8726b58e3)
  - `is_writing_stopped = true` shouldn't be set directly, there's a helper function `s_stop_writing()` that ensures subsequent calls to `aws_websocket_send_frame()` will fail.

Let's take a whole new approach these channel-shutdown-window-deadlock issues:
- add `s_stop_reading_and_dont_block_shutdown()` function that sets `is_reading_stopped = true`, but also increments the read window so that channel shutdown won't deadlock.
    - Most places that were setting `is_reading_stopped = true` now use this helper instead
- Revamp how `aws_channel_shutdown()` is called. Lots of channel behavior has changed since [this websocket code was written](#48).
  - If on the channel-thread, just call `aws_channel_shutdown()`
      - now that [aws_channel_shutdown()](awslabs/aws-c-io#172) is always async, we don't need to defensively schedule a task to call it
  - If off-thread, use `s_schedule_channel_shutdown_from_offthead()`
      - now that this is only called from `aws_websocket_close()`, or when the refcount goes to zero, we can assume the user is OK if reading stops, and it can call `s_stop_reading_and_dont_block_shutdown()` on the way to shutting down.
- Add the test to verify that send after close should fail

Co-authored-by: Michael Graeb <[email protected]>
  • Loading branch information
TingDaoK and graebm authored Aug 16, 2024
1 parent 7db2452 commit 4e74ab1
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 64 deletions.
125 changes: 61 additions & 64 deletions source/websocket.c
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ struct aws_websocket {
aws_websocket_on_incoming_frame_complete_fn *on_incoming_frame_complete;

struct aws_channel_task move_synced_data_to_thread_task;
struct aws_channel_task shutdown_channel_task;
struct aws_channel_task shutdown_channel_from_offthread_task;
struct aws_channel_task increment_read_window_task;
struct aws_channel_task waiting_on_payload_stream_task;
struct aws_channel_task close_timeout_task;
Expand Down Expand Up @@ -85,7 +85,10 @@ struct aws_websocket {
/* True when no more frames will be read, due to:
* - a CLOSE frame was received
* - decoder error
* - channel shutdown in read-dir */
* - channel shutdown in read-dir
* - user calling aws_websocket_close()
* - user dropping the last refcount
*/
bool is_reading_stopped;

/* True when no more frames will be written, due to:
Expand Down Expand Up @@ -124,9 +127,9 @@ struct aws_websocket {
/* Error-code returned by aws_websocket_send_frame() when is_writing_stopped is true */
int send_frame_error_code;

/* Use a task to issue a channel shutdown. */
int shutdown_channel_task_error_code;
bool is_shutdown_channel_task_scheduled;
/* Use a task to issue a channel shutdown from off-thread. */
int shutdown_channel_from_offthread_task_error_code;
bool is_shutdown_channel_from_offthread_task_scheduled;

bool is_move_synced_data_to_thread_task_scheduled;

Expand Down Expand Up @@ -186,10 +189,13 @@ static bool s_midchannel_send_payload(struct aws_websocket *websocket, struct aw
static void s_midchannel_send_complete(struct aws_websocket *websocket, int error_code, void *user_data);
static void s_move_synced_data_to_thread_task(struct aws_channel_task *task, void *arg, enum aws_task_status status);
static void s_increment_read_window_task(struct aws_channel_task *task, void *arg, enum aws_task_status status);
static void s_shutdown_channel_task(struct aws_channel_task *task, void *arg, enum aws_task_status status);
static void s_shutdown_channel_from_offthread_task(
struct aws_channel_task *task,
void *arg,
enum aws_task_status status);
static void s_waiting_on_payload_stream_task(struct aws_channel_task *task, void *arg, enum aws_task_status status);
static void s_close_timeout_task(struct aws_channel_task *task, void *arg, enum aws_task_status status);
static void s_schedule_channel_shutdown(struct aws_websocket *websocket, int error_code);
static void s_schedule_channel_shutdown_from_offthread(struct aws_websocket *websocket, int error_code);
static void s_shutdown_due_to_write_err(struct aws_websocket *websocket, int error_code);
static void s_shutdown_due_to_read_err(struct aws_websocket *websocket, int error_code);
static void s_stop_writing(struct aws_websocket *websocket, int send_frame_error_code);
Expand Down Expand Up @@ -285,7 +291,10 @@ struct aws_websocket *aws_websocket_handler_new(const struct aws_websocket_handl
websocket,
"websocket_move_synced_data_to_thread");
aws_channel_task_init(
&websocket->shutdown_channel_task, s_shutdown_channel_task, websocket, "websocket_shutdown_channel");
&websocket->shutdown_channel_from_offthread_task,
s_shutdown_channel_from_offthread_task,
websocket,
"websocket_shutdown_channel");
aws_channel_task_init(
&websocket->increment_read_window_task,
s_increment_read_window_task,
Expand Down Expand Up @@ -377,7 +386,7 @@ static void s_websocket_on_refcount_zero(void *user_data) {
AWS_LS_HTTP_WEBSOCKET, "id=%p: Websocket ref-count is zero, shut down if necessary.", (void *)websocket);

/* Channel might already be shut down, but make sure */
s_schedule_channel_shutdown(websocket, AWS_ERROR_SUCCESS);
s_schedule_channel_shutdown_from_offthread(websocket, AWS_ERROR_SUCCESS);

/* Channel won't destroy its slots/handlers until its refcount reaches 0 */
aws_channel_release_hold(websocket->channel_slot->channel);
Expand Down Expand Up @@ -897,6 +906,21 @@ static void s_complete_frame_list(struct aws_websocket *websocket, struct aws_li
aws_linked_list_init(frames);
}

/* Set is_reading_stopped = true, all further read data will be ignored.
* But also increment the read window, so that channel shutdown won't deadlock
* due to pending read-data in an upstream handler or the underlying OS socket. */
static void s_stop_reading_and_dont_block_shutdown(struct aws_websocket *websocket) {
AWS_ASSERT(aws_channel_thread_is_callers_thread(websocket->channel_slot->channel));
if (websocket->thread_data.is_reading_stopped) {
return;
}

AWS_LOGF_TRACE(AWS_LS_HTTP_WEBSOCKET, "id=%p: Websocket will ignore any further read data.", (void *)websocket);
websocket->thread_data.is_reading_stopped = true;

aws_channel_slot_increment_read_window(websocket->channel_slot, SIZE_MAX);
}

static void s_stop_writing(struct aws_websocket *websocket, int send_frame_error_code) {
AWS_ASSERT(aws_channel_thread_is_callers_thread(websocket->channel_slot->channel));
AWS_ASSERT(send_frame_error_code != AWS_ERROR_SUCCESS);
Expand Down Expand Up @@ -947,7 +971,7 @@ static void s_shutdown_due_to_write_err(struct aws_websocket *websocket, int err
(void *)websocket,
error_code,
aws_error_name(error_code));
s_schedule_channel_shutdown(websocket, error_code);
aws_channel_shutdown(websocket->channel_slot->channel, error_code);
}
}

Expand All @@ -961,18 +985,22 @@ static void s_shutdown_due_to_read_err(struct aws_websocket *websocket, int erro
error_code,
aws_error_name(error_code));

websocket->thread_data.is_reading_stopped = true;
s_stop_reading_and_dont_block_shutdown(websocket);

/* If there's a current incoming frame, complete it with the specific error code. */
if (websocket->thread_data.current_incoming_frame) {
s_complete_incoming_frame(websocket, error_code, NULL);
}

/* Tell channel to shutdown (it's ok to call this redundantly) */
s_schedule_channel_shutdown(websocket, error_code);
aws_channel_shutdown(websocket->channel_slot->channel, error_code);
}

static void s_shutdown_channel_task(struct aws_channel_task *task, void *arg, enum aws_task_status status) {
static void s_shutdown_channel_from_offthread_task(
struct aws_channel_task *task,
void *arg,
enum aws_task_status status) {

(void)task;

if (status != AWS_TASK_STATUS_RUN_READY) {
Expand All @@ -985,39 +1013,39 @@ static void s_shutdown_channel_task(struct aws_channel_task *task, void *arg, en
/* BEGIN CRITICAL SECTION */
s_lock_synced_data(websocket);

error_code = websocket->synced_data.shutdown_channel_task_error_code;
error_code = websocket->synced_data.shutdown_channel_from_offthread_task_error_code;

s_unlock_synced_data(websocket);
/* END CRITICAL SECTION */
websocket->thread_data.is_reading_stopped = true;
websocket->thread_data.is_writing_stopped = true;

/* Stop reading, so that shutdown won't be blocked.
* If something off-thread is causing shutdown (aws_websocket_close(), refcount 0, etc),
* the user may never interact with the websocket again. We can't rely on them
* to keep the window open and prevent deadlock during shutdown. */
s_stop_reading_and_dont_block_shutdown(websocket);

aws_channel_shutdown(websocket->channel_slot->channel, error_code);
/* Increase the window size after shutdown starts, to prevent deadlock when data still pending in the upstream
* handler. */
aws_channel_slot_increment_read_window(websocket->channel_slot, SIZE_MAX);
}

/* Tell the channel to shut down. It is safe to call this multiple times.
* The call to aws_channel_shutdown() is delayed so that a user invoking aws_websocket_close doesn't
* have completion callbacks firing before the function call even returns */
static void s_schedule_channel_shutdown(struct aws_websocket *websocket, int error_code) {
/* Tell the channel to shut down, from off-thread. It is safe to call this multiple times. */
static void s_schedule_channel_shutdown_from_offthread(struct aws_websocket *websocket, int error_code) {
bool schedule_shutdown = false;

/* BEGIN CRITICAL SECTION */
s_lock_synced_data(websocket);

if (!websocket->synced_data.is_shutdown_channel_task_scheduled) {
if (!websocket->synced_data.is_shutdown_channel_from_offthread_task_scheduled) {
schedule_shutdown = true;
websocket->synced_data.is_shutdown_channel_task_scheduled = true;
websocket->synced_data.shutdown_channel_task_error_code = error_code;
websocket->synced_data.is_shutdown_channel_from_offthread_task_scheduled = true;
websocket->synced_data.shutdown_channel_from_offthread_task_error_code = error_code;
}

s_unlock_synced_data(websocket);
/* END CRITICAL SECTION */

if (schedule_shutdown) {
aws_channel_schedule_task_now(websocket->channel_slot->channel, &websocket->shutdown_channel_task);
aws_channel_schedule_task_now(
websocket->channel_slot->channel, &websocket->shutdown_channel_from_offthread_task);
}
}

Expand All @@ -1038,14 +1066,13 @@ void aws_websocket_close(struct aws_websocket *websocket, bool free_scarce_resou
return;
}

/* TODO: aws_channel_shutdown() should let users specify error_code and "immediate" as separate parameters.
* Currently, any non-zero error_code results in "immediate" shutdown */
/* TODO: aws_channel_shutdown() should let users specify error_code and "immediate" as separate parameters. */
int error_code = AWS_ERROR_SUCCESS;
if (free_scarce_resources_immediately) {
error_code = AWS_ERROR_HTTP_CONNECTION_CLOSED;
}

s_schedule_channel_shutdown(websocket, error_code);
s_schedule_channel_shutdown_from_offthread(websocket, error_code);
}

static int s_handler_shutdown(
Expand Down Expand Up @@ -1255,17 +1282,7 @@ static int s_handler_process_read_message(
}

if (websocket->thread_data.incoming_message_window_update > 0) {
err = aws_channel_slot_increment_read_window(slot, websocket->thread_data.incoming_message_window_update);
if (err) {
AWS_LOGF_ERROR(
AWS_LS_HTTP_WEBSOCKET,
"id=%p: Failed to increment read window after message processing, error %d (%s). Closing "
"connection.",
(void *)websocket,
aws_last_error(),
aws_error_name(aws_last_error()));
goto error;
}
aws_channel_slot_increment_read_window(slot, websocket->thread_data.incoming_message_window_update);
}

goto clean_up;
Expand Down Expand Up @@ -1508,7 +1525,7 @@ static void s_complete_incoming_frame(struct aws_websocket *websocket, int error
AWS_LS_HTTP_WEBSOCKET,
"id=%p: Close frame received, any further data received will be ignored.",
(void *)websocket);
websocket->thread_data.is_reading_stopped = true;
s_stop_reading_and_dont_block_shutdown(websocket);

/* TODO: auto-close if there's a channel-handler to the right */

Expand Down Expand Up @@ -1598,37 +1615,17 @@ static int s_handler_increment_read_window(
}

if (increment != 0) {
int err = aws_channel_slot_increment_read_window(slot, increment);
if (err) {
goto error;
}
aws_channel_slot_increment_read_window(slot, increment);
}

return AWS_OP_SUCCESS;

error:
websocket->thread_data.is_reading_stopped = true;
/* Shutting down channel because I know that no one ever checks these errors */
s_shutdown_due_to_read_err(websocket, aws_last_error());
return AWS_OP_ERR;
}

static void s_increment_read_window_action(struct aws_websocket *websocket, size_t size) {
AWS_ASSERT(aws_channel_thread_is_callers_thread(websocket->channel_slot->channel));

int err = aws_channel_slot_increment_read_window(websocket->channel_slot, size);
if (err) {
AWS_LOGF_ERROR(
AWS_LS_HTTP_WEBSOCKET,
"id=%p: Failed to increment read window, error %d (%s). Closing websocket.",
(void *)websocket,
aws_last_error(),
aws_error_name(aws_last_error()));

s_schedule_channel_shutdown(websocket, aws_last_error());
}
}

static void s_increment_read_window_task(struct aws_channel_task *task, void *arg, enum aws_task_status status) {
(void)task;

Expand All @@ -1651,7 +1648,7 @@ static void s_increment_read_window_task(struct aws_channel_task *task, void *ar
AWS_LOGF_TRACE(
AWS_LS_HTTP_WEBSOCKET, "id=%p: Running task to increment read window by %zu.", (void *)websocket, size);

s_increment_read_window_action(websocket, size);
aws_channel_slot_increment_read_window(websocket->channel_slot, size);
}

void aws_websocket_increment_read_window(struct aws_websocket *websocket, size_t size) {
Expand Down
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ add_test_case(websocket_handler_window_manual_increment)
add_test_case(websocket_handler_window_manual_increment_off_thread)
add_test_case(websocket_handler_sends_pong_automatically)
add_test_case(websocket_handler_wont_send_pong_after_close_frame)
add_test_case(websocket_handler_send_frame_fails_if_websocket_closed)
add_test_case(websocket_midchannel_sanity_check)
add_test_case(websocket_midchannel_write_message)
add_test_case(websocket_midchannel_write_multiple_messages)
Expand Down
32 changes: 32 additions & 0 deletions tests/test_websocket_handler.c
Original file line number Diff line number Diff line change
Expand Up @@ -1883,6 +1883,38 @@ TEST_CASE(websocket_handler_wont_send_pong_after_close_frame) {
return AWS_OP_SUCCESS;
}

/* This is a regression test. If aws_websocket_close() leads to shutdown,
* then subsequent calls to aws_websocket_send_frame() should fail. */
TEST_CASE(websocket_handler_send_frame_fails_if_websocket_closed) {
(void)ctx;
(void)ctx;
struct tester tester;
ASSERT_SUCCESS(s_tester_init(&tester, allocator));

/* Call aws_websocket_close() and wait for shutdown to complete */
testing_channel_set_is_on_users_thread(&tester.testing_channel, false);
aws_websocket_close(tester.websocket, false);
testing_channel_set_is_on_users_thread(&tester.testing_channel, true);

ASSERT_SUCCESS(s_drain_written_messages(&tester));
ASSERT_TRUE(testing_channel_is_shutdown_completed(&tester.testing_channel));

/* aws_websocket_send_frame() should fail */
struct aws_byte_cursor payload = aws_byte_cursor_from_c_str("bitter butter.");
struct send_tester send = {
.payload = payload,
.def =
{
.opcode = AWS_WEBSOCKET_OPCODE_PING,
.fin = true,
},
};
ASSERT_FAILS(s_send_frame(&tester, &send));
ASSERT_UINT_EQUALS(AWS_ERROR_HTTP_WEBSOCKET_CLOSE_FRAME_SENT, aws_last_error());
ASSERT_SUCCESS(s_tester_clean_up(&tester));
return AWS_OP_SUCCESS;
}

TEST_CASE(websocket_midchannel_read_message) {
(void)ctx;
struct tester tester;
Expand Down

0 comments on commit 4e74ab1

Please sign in to comment.