Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
4306: Clear the td states on ECALL return r=mingweishih a=mingweishih

Ensure the td states are cleared on ECALL return so the the following ECALL using the same tcs will not inherit the states.

Signed-off-by: Ming-Wei Shih <[email protected]>

Co-authored-by: Ming-Wei Shih <[email protected]>
  • Loading branch information
oeciteam and mingweishih committed Nov 17, 2021
2 parents 99b21cb + 68acf9d commit 264c967
Show file tree
Hide file tree
Showing 13 changed files with 236 additions and 54 deletions.
4 changes: 3 additions & 1 deletion enclave/core/sgx/calls.c
Original file line number Diff line number Diff line change
Expand Up @@ -1223,7 +1223,9 @@ void oe_abort(void)
{
oe_sgx_td_t* td = oe_sgx_get_td();

td->state = OE_TD_STATE_ABORTED;
/* only update the state if td is initialized */
if (td)
td->state = OE_TD_STATE_ABORTED;

// Once it starts to crash, the state can only transit forward, not
// backward.
Expand Down
62 changes: 51 additions & 11 deletions enclave/core/sgx/td.c
Original file line number Diff line number Diff line change
Expand Up @@ -172,16 +172,55 @@ oe_sgx_td_t* oe_sgx_get_td()
/*
**==============================================================================
**
** oe_sgx_set_td_exception_handler_stack()
** oe_sgx_clear_td_states()
**
** Internal API that allows an enclave to clear the td states.
**
**==============================================================================
*/

void oe_sgx_td_clear_states(oe_sgx_td_t* td)
{
/* Mask host signals by default */
oe_sgx_td_mask_host_signal(td);

/* Clear exception-related information */
td->exception_code = 0;
td->exception_flags = 0;
td->exception_address = 0;
td->faulting_address = 0;
td->error_code = 0;
td->last_ssa_rsp = 0;
td->last_ssa_rbp = 0;

/* Clear states related host signal handling */
td->exception_nesting_level = 0;
td->is_handling_host_signal = 0;
td->host_signal_bitmask = 0;
td->host_signal = 0;

/* Clear the states of the state machine */
td->previous_state = OE_TD_STATE_NULL;
td->state = OE_TD_STATE_RUNNING; // the default state during the runtime
}

/*
**==============================================================================
**
** oe_sgx_td_set_exception_handler_stack()
**
** Internal API that allows an enclave to setup stack area for
** exception handlers to use.
**
**==============================================================================
*/
bool oe_sgx_set_td_exception_handler_stack(void* stack, uint64_t size)
bool oe_sgx_td_set_exception_handler_stack(
oe_sgx_td_t* td,
void* stack,
uint64_t size)
{
oe_sgx_td_t* td = oe_sgx_get_td();
if (!td)
return false;

/* ensure stack + size is 16-byte aligned */
if (((uint64_t)stack + size) % 16)
Expand All @@ -204,21 +243,22 @@ bool oe_sgx_set_td_exception_handler_stack(void* stack, uint64_t size)
**==============================================================================
*/

OE_INLINE void _set_td_host_signal_unmasked(uint64_t value)
OE_INLINE void _set_td_host_signal_unmasked(oe_sgx_td_t* td, uint64_t value)
{
oe_sgx_td_t* td = oe_sgx_get_td();
if (!td)
return;

td->host_signal_unmasked = value;
}

void oe_sgx_td_mask_host_signal()
void oe_sgx_td_mask_host_signal(oe_sgx_td_t* td)
{
_set_td_host_signal_unmasked(0);
_set_td_host_signal_unmasked(td, 0);
}

void oe_sgx_td_unmask_host_signal()
void oe_sgx_td_unmask_host_signal(oe_sgx_td_t* td)
{
_set_td_host_signal_unmasked(1);
_set_td_host_signal_unmasked(td, 1);
}

/*
Expand Down Expand Up @@ -257,12 +297,12 @@ OE_INLINE bool _set_td_host_signal_bitmask(
return true;
}

bool oe_sgx_register_td_host_signal(oe_sgx_td_t* td, int signal_number)
bool oe_sgx_td_register_host_signal(oe_sgx_td_t* td, int signal_number)
{
return _set_td_host_signal_bitmask(td, signal_number, 1 /* set */);
}

bool oe_sgx_unregister_td_host_signal(oe_sgx_td_t* td, int signal_number)
bool oe_sgx_td_unregister_host_signal(oe_sgx_td_t* td, int signal_number)
{
return _set_td_host_signal_bitmask(td, signal_number, 0 /* clear */);
}
Expand Down
6 changes: 6 additions & 0 deletions enclave/core/sgx/td_basic.c
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ void td_init(oe_sgx_td_t* td)
/* List of callsites is initially empty */
td->callsites = NULL;

/* Set the exception handler stack to NULL */
oe_sgx_td_set_exception_handler_stack(td, NULL, 0);

oe_thread_local_init(td);
}
}
Expand Down Expand Up @@ -152,5 +155,8 @@ void td_clear(oe_sgx_td_t* td)
/* Clear the magic number */
td->magic = 0;

/* Clear td states */
oe_sgx_td_clear_states(td);

/* Never clear oe_sgx_td_t.initialized nor host registers */
}
18 changes: 9 additions & 9 deletions include/openenclave/internal/sgx/td.h
Original file line number Diff line number Diff line change
Expand Up @@ -227,20 +227,20 @@ OE_STATIC_ASSERT(
/* Get the thread data object for the current thread */
oe_sgx_td_t* oe_sgx_get_td(void);

/* The following APIs are expected to be used only by the thread itself. */
void oe_sgx_td_clear_states(oe_sgx_td_t* td);

bool oe_sgx_set_td_exception_handler_stack(void* stack, uint64_t size);
bool oe_sgx_td_set_exception_handler_stack(
oe_sgx_td_t* td,
void* stack,
uint64_t size);

void oe_sgx_td_mask_host_signal();
void oe_sgx_td_mask_host_signal(oe_sgx_td_t* td);

void oe_sgx_td_unmask_host_signal();
void oe_sgx_td_unmask_host_signal(oe_sgx_td_t* td);

/* The following APIs are expected to be used by both the thread itself
* and other threads. */
bool oe_sgx_td_register_host_signal(oe_sgx_td_t* td, int signal_number);

bool oe_sgx_register_td_host_signal(oe_sgx_td_t* td, int signal_number);

bool oe_sgx_unregister_td_host_signal(oe_sgx_td_t* td, int signal_number);
bool oe_sgx_td_unregister_host_signal(oe_sgx_td_t* td, int signal_number);

bool oe_sgx_td_host_signal_registered(oe_sgx_td_t* td, int signal_number);

Expand Down
9 changes: 6 additions & 3 deletions tests/VectorException/enc/enc.c
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,19 @@ int initialize_exception_handler_stack(
uint64_t* stack_size,
int use_exception_handler_stack)
{
oe_sgx_td_t* td = oe_sgx_get_td();

if (use_exception_handler_stack)
{
*stack_size = EXCEPTION_HANDLER_STACK_SIZE;
*stack = malloc(*stack_size);
if (!*stack)
return -1;
if (!oe_sgx_set_td_exception_handler_stack(*stack, *stack_size))
if (!oe_sgx_td_set_exception_handler_stack(td, *stack, *stack_size))
return -1;
}
else
{
oe_sgx_td_t* td = oe_sgx_get_td();
void* tcs = td_to_tcs(td);
*stack_size = STACK_SIZE;
*stack = (void*)((uint64_t)tcs - PAGE_SIZE - STACK_SIZE);
Expand All @@ -48,9 +49,11 @@ void cleaup_exception_handler_stack(
uint64_t* stack_size,
int use_exception_handler_stack)
{
oe_sgx_td_t* td = oe_sgx_get_td();

if (use_exception_handler_stack)
{
oe_sgx_set_td_exception_handler_stack(NULL, 0);
oe_sgx_td_set_exception_handler_stack(td, NULL, 0);
free(*stack);
}

Expand Down
6 changes: 4 additions & 2 deletions tests/VectorException/enc/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,14 @@ __attribute__((constructor)) void test_cpuid_constructor()
void* stack = malloc(EXCEPTION_HANDLER_STACK_SIZE);
if (!stack)
return;
oe_sgx_set_td_exception_handler_stack(stack, EXCEPTION_HANDLER_STACK_SIZE);
oe_sgx_td_t* td = oe_sgx_get_td();
oe_sgx_td_set_exception_handler_stack(
td, stack, EXCEPTION_HANDLER_STACK_SIZE);
test_cpuid_instruction(500, 1);
test_cpuid_instruction(600, 1);
test_cpuid_instruction(AESNI_INSTRUCTIONS, 1);

oe_sgx_set_td_exception_handler_stack(NULL, 0);
oe_sgx_td_set_exception_handler_stack(td, NULL, 0);
free(stack);
}

Expand Down
95 changes: 87 additions & 8 deletions tests/sgx/td_state/enc/enc.c
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <openenclave/corelibc/string.h>
#include <openenclave/enclave.h>
#include <openenclave/internal/jump.h>
#include <openenclave/internal/print.h>
#include <openenclave/internal/raise.h>
#include <openenclave/internal/sgx/td.h>
Expand All @@ -11,6 +12,7 @@

#include <signal.h>
#include <stdio.h>
#include <stdlib.h>

#define OE_EXPECT(a, b) \
do \
Expand All @@ -30,11 +32,7 @@
} \
} while (0)

bool oe_sgx_register_target_td_host_signal(
oe_sgx_td_t* target_td,
int signal_number);

typedef struct _thread_info_nonblocking_t
typedef struct _thread_info_t
{
int tid;
oe_sgx_td_t* td;
Expand All @@ -44,6 +42,9 @@ static thread_info_t _thread_info;
static volatile int _handler_done;
static volatile int* _host_lock_state;

static thread_info_t _thread_handler_no_return_info;
static oe_jmpbuf_t jump_buffer;

static void cpuid(
unsigned int leaf,
unsigned int subleaf,
Expand Down Expand Up @@ -163,7 +164,7 @@ static uint64_t td_state_handler(oe_exception_record_t* exception_record)
OE_TEST(oe_sgx_td_is_handling_host_signal(_thread_info.td));

OE_TEST(
oe_sgx_unregister_td_host_signal(_thread_info.td, SIGUSR1) == true);
oe_sgx_td_unregister_host_signal(_thread_info.td, SIGUSR1) == true);

__atomic_store_n(_host_lock_state, 2, __ATOMIC_RELEASE);

Expand Down Expand Up @@ -221,7 +222,7 @@ void enc_run_thread(int tid)
OE_CHECK(oe_add_vectored_exception_handler(false, td_state_handler));

// Invoke the internal API to unmask host signals
oe_sgx_td_unmask_host_signal();
oe_sgx_td_unmask_host_signal(_thread_info.td);

// Ensure the order of setting the lock
asm volatile("" ::: "memory");
Expand Down Expand Up @@ -273,6 +274,8 @@ void enc_run_thread(int tid)
// Expect the state to be persisted after an exception.
OE_EXPECT(_thread_info.td->state, OE_TD_STATE_RUNNING);

OE_CHECK(oe_remove_vectored_exception_handler(td_state_handler));

printf("(tid=%d) thread is exiting...\n", self_tid);
done:
return;
Expand Down Expand Up @@ -308,7 +311,7 @@ void enc_td_state(uint64_t lock_state)
OE_TEST(_thread_info.tid != 0);
host_sleep_msec(30);

OE_TEST(oe_sgx_register_td_host_signal(_thread_info.td, SIGUSR1) == true);
OE_TEST(oe_sgx_td_register_host_signal(_thread_info.td, SIGUSR1) == true);

printf(
"(tid=%d) Sending interrupt to (td=0x%lx, tid=%d) inside the "
Expand Down Expand Up @@ -346,6 +349,82 @@ void enc_td_state(uint64_t lock_state)
OE_EXPECT(_thread_info.td->state, OE_TD_STATE_EXITED);
}

static uint64_t td_state_handler_no_return(
oe_exception_record_t* exception_record)
{
if (exception_record->code == OE_EXCEPTION_DIVIDE_BY_ZERO)
{
oe_longjmp(&jump_buffer, 1);
}

return OE_EXCEPTION_ABORT_EXECUTION;
}

void enc_run_thread_handler_no_return(int tid)
{
oe_result_t result = OE_OK;

_thread_handler_no_return_info.tid = tid;
_thread_handler_no_return_info.td = oe_sgx_get_td();

printf(
"(tid=%d) thread is created td=0x%lx\n",
_thread_handler_no_return_info.tid,
(uint64_t)_thread_handler_no_return_info.td);

OE_CHECK(
oe_add_vectored_exception_handler(false, td_state_handler_no_return));

if (oe_setjmp(&jump_buffer) == 0)
divide_by_zero_exception_function();

// Expect the state is still OE_TD_STATE_SECOND_LEVEL_EXCEPTION_HANDLING
// (the handler does not return)
OE_EXPECT(
_thread_handler_no_return_info.td->state,
OE_TD_STATE_SECOND_LEVEL_EXCEPTION_HANDLING);

done:
return;
}

void enc_run_thread_reuse_tcs(int tid)
{
oe_sgx_td_t* td = oe_sgx_get_td();

// Expect the tcs is re-used
OE_EXPECT(td, _thread_handler_no_return_info.td);

OE_EXPECT(_thread_handler_no_return_info.td->state, OE_TD_STATE_RUNNING);

printf("(tid=%d) thread is created td=0x%lx\n", tid, (uint64_t)td);
}

void enc_td_state_handler_no_return()
{
oe_result_t result;
int tid = 0;

host_get_tid(&tid);
OE_TEST(tid != 0);

printf("(tid=%d) Create a thread...\n", tid);

result = host_create_thread_handler_no_return();
if (result != OE_OK)
return;

host_join_thread();

printf("(tid=%d) Create a thread...\n", tid);

result = host_create_thread_reuse_tcs();
if (result != OE_OK)
return;

host_join_thread();
}

OE_SET_ENCLAVE_SGX(
1, /* ProductID */
1, /* SecurityVersion */
Expand Down
Loading

0 comments on commit 264c967

Please sign in to comment.