Skip to content

Commit

Permalink
Make _thread.ThreadHandle thread-safe in free-threaded builds
Browse files Browse the repository at this point in the history
We protect the mutable state of `ThreadHandle` using a `_PyOnceFlag`.
Concurrent operations (i.e. `join` or `detach`) on `ThreadHandle` block
until it is their turn to execute or an earlier operation succeeds.
Once an operation has been applied successfully all future operations
complete immediately.

The `join()` method is now idempotent. It may be called multiple times
but the underlying OS thread will only be joined once. After `join()`
succeeds, any future calls to `join()` will succeed immediately.

The `detach()` method is also idempotent. It may be called multiple times
but the underlying OS thread will only be detached once. After `detach()`
succeeds, any future calls to `detach()` will succeed immediately.

If the handle is being joined, `detach()` blocks until the join completes.
  • Loading branch information
mpage committed Feb 8, 2024
1 parent ef3ceab commit 1eeef32
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 42 deletions.
9 changes: 4 additions & 5 deletions Lib/test/test_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,7 @@ def task():
with threading_helper.wait_threads_exit():
handle = thread.start_joinable_thread(task)
handle.join()
with self.assertRaisesRegex(ValueError, "not joinable"):
handle.join()
handle.join()

def test_joinable_not_joined(self):
handle_destroyed = thread.allocate_lock()
Expand Down Expand Up @@ -255,7 +254,7 @@ def task():
handles.append(handle)
start_joinable_thread_returned.release()
thread_detached.acquire()
with self.assertRaisesRegex(ValueError, "not joinable"):
with self.assertRaisesRegex(ValueError, "detached and thus cannot be joined"):
handle.join()

assert len(errors) == 0
Expand All @@ -272,7 +271,7 @@ def task():
# detach() returns even though the thread is blocked on lock
handle.detach()
# join() then cannot be called anymore
with self.assertRaisesRegex(ValueError, "not joinable"):
with self.assertRaisesRegex(ValueError, "detached and thus cannot be joined"):
handle.join()
lock.release()

Expand All @@ -283,7 +282,7 @@ def task():
with threading_helper.wait_threads_exit():
handle = thread.start_joinable_thread(task)
handle.join()
with self.assertRaisesRegex(ValueError, "not joinable"):
with self.assertRaisesRegex(ValueError, "joined and thus cannot be detached"):
handle.detach()


Expand Down
168 changes: 131 additions & 37 deletions Modules/_threadmodule.c
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@

/* Thread module */
/* Interface to Sjoerd's portable C thread library */

#include "Python.h"
#include "pycore_interp.h" // _PyInterpreterState.threads.count
#include "pycore_lock.h"
#include "pycore_moduleobject.h" // _PyModule_GetState()
#include "pycore_modsupport.h" // _PyArg_NoKeywords()
#include "pycore_pylifecycle.h"
Expand Down Expand Up @@ -42,14 +42,49 @@ get_thread_state(PyObject *module)

// _ThreadHandle type

typedef enum {
THREAD_HANDLE_INVALID,
THREAD_HANDLE_JOINED,
THREAD_HANDLE_DETACHED,
} ThreadHandleState;

// A handle to join or detach an OS thread.
//
// Joining or detaching the handle is idempotent; the underlying OS thread is
// joined or detached only once. Concurrent operations block until it is their
// turn to execute or an operation completes successfully. Once an operation
// has completed successfully all future operations complete immediately.
typedef struct {
PyObject_HEAD
struct llist_node node; // linked list node (see _pythread_runtime_state)

// The `ident` and `handle` fields are immutable once the object is visible
// to threads other than its creator, thus they do not need to be accessed
// atomically.
PyThread_ident_t ident;
PyThread_handle_t handle;
char joinable;

// State is set once by the first successful `join` or `detach` operation
// (or if the handle is invalidated).
ThreadHandleState state;
_PyOnceFlag once;
} ThreadHandleObject;

// An operation on a ThreadHandle that sets the state and return 0 on success
// or returns -1 on error.
typedef int (*ThreadHandleOp)(ThreadHandleObject *);

// Maybe execute op on the handle. Returns -1 if op was called and returned -1.
static int
try_thread_handle_op(ThreadHandleOp op, ThreadHandleObject *handle)
{
if (_PyOnceFlag_CallOnce(&handle->once, (_Py_once_fn_t *)op, handle) ==
-1) {
return -1;
}
return handle->state;
}

static ThreadHandleObject*
new_thread_handle(thread_module_state* state)
{
Expand All @@ -59,7 +94,7 @@ new_thread_handle(thread_module_state* state)
}
self->ident = 0;
self->handle = 0;
self->joinable = 0;
self->once = (_PyOnceFlag){0};

HEAD_LOCK(&_PyRuntime);
llist_insert_tail(&_PyRuntime.threads.handles, &self->node);
Expand All @@ -68,6 +103,17 @@ new_thread_handle(thread_module_state* state)
return self;
}

static int
detach_thread(ThreadHandleObject *handle)
{
// This is typically short so no need to release the GIL
if (PyThread_detach_thread(handle->handle)) {
return -1;
}
handle->state = THREAD_HANDLE_DETACHED;
return 0;
}

static void
ThreadHandle_dealloc(ThreadHandleObject *self)
{
Expand All @@ -80,17 +126,30 @@ ThreadHandle_dealloc(ThreadHandleObject *self)
}
HEAD_UNLOCK(&_PyRuntime);

if (self->joinable) {
int ret = PyThread_detach_thread(self->handle);
if (ret) {
PyErr_SetString(ThreadError, "Failed detaching thread");
PyErr_WriteUnraisable(tp);
}
if (try_thread_handle_op(detach_thread, self) == -1) {
PyErr_SetString(ThreadError, "Failed detaching thread");
PyErr_WriteUnraisable(tp);
}
PyObject_Free(self);
Py_DECREF(tp);
}

static int
do_invalidate_thread_handle(ThreadHandleObject *handle)
{
handle->state = THREAD_HANDLE_INVALID;
return 0;
}

static void
invalidate_thread_handle(ThreadHandleObject *handle)
{
if (try_thread_handle_op(do_invalidate_thread_handle, handle) == -1) {
Py_FatalError("failed invalidating thread handle");
Py_UNREACHABLE();
}
}

void
_PyThread_AfterFork(struct _pythread_runtime_state *state)
{
Expand All @@ -108,7 +167,7 @@ _PyThread_AfterFork(struct _pythread_runtime_state *state)
}

// Disallow calls to detach() and join() as they could crash.
hobj->joinable = 0;
invalidate_thread_handle(hobj);
llist_remove(node);
}
}
Expand All @@ -126,49 +185,84 @@ ThreadHandle_get_ident(ThreadHandleObject *self, void *ignored)
return PyLong_FromUnsignedLongLong(self->ident);
}

static PyObject *
invalid_handle_error(void)
{
PyErr_SetString(PyExc_ValueError,
"the handle is invalid and thus cannot be detached");
return NULL;
}

static PyObject *
ThreadHandle_detach(ThreadHandleObject *self, void* ignored)
{
if (!self->joinable) {
PyErr_SetString(PyExc_ValueError,
"the thread is not joinable and thus cannot be detached");
return NULL;
switch (try_thread_handle_op(detach_thread, self)) {
case -1: {
PyErr_SetString(ThreadError, "Failed detaching thread");
return NULL;
}
case THREAD_HANDLE_INVALID: {
return invalid_handle_error();
}
case THREAD_HANDLE_JOINED: {
PyErr_SetString(
PyExc_ValueError,
"the thread has been joined and thus cannot be detached");
return NULL;
}
case THREAD_HANDLE_DETACHED: {
Py_RETURN_NONE;
}
default: {
Py_UNREACHABLE();
}
}
self->joinable = 0;
// This is typically short so no need to release the GIL
int ret = PyThread_detach_thread(self->handle);
if (ret) {
PyErr_SetString(ThreadError, "Failed detaching thread");
return NULL;
}

static int
join_thread(ThreadHandleObject *handle)
{
int err;
Py_BEGIN_ALLOW_THREADS
err = PyThread_join_thread(handle->handle);
Py_END_ALLOW_THREADS
if (err) {
return -1;
}
Py_RETURN_NONE;
handle->state = THREAD_HANDLE_JOINED;
return 0;
}

static PyObject *
ThreadHandle_join(ThreadHandleObject *self, void* ignored)
{
if (!self->joinable) {
PyErr_SetString(PyExc_ValueError, "the thread is not joinable");
return NULL;
}
if (self->ident == PyThread_get_thread_ident_ex()) {
// PyThread_join_thread() would deadlock or error out.
PyErr_SetString(ThreadError, "Cannot join current thread");
return NULL;
}
// Before actually joining, we must first mark the thread as non-joinable,
// as joining several times simultaneously or sequentially is undefined behavior.
self->joinable = 0;
int ret;
Py_BEGIN_ALLOW_THREADS
ret = PyThread_join_thread(self->handle);
Py_END_ALLOW_THREADS
if (ret) {
PyErr_SetString(ThreadError, "Failed joining thread");
return NULL;

switch (try_thread_handle_op(join_thread, self)) {
case -1: {
PyErr_SetString(ThreadError, "Failed joining thread");
return NULL;
}
case THREAD_HANDLE_INVALID: {
return invalid_handle_error();
}
case THREAD_HANDLE_JOINED: {
Py_RETURN_NONE;
}
case THREAD_HANDLE_DETACHED: {
PyErr_SetString(
PyExc_ValueError,
"the thread is detached and thus cannot be joined");
return NULL;
}
default: {
Py_UNREACHABLE();
}
}
Py_RETURN_NONE;
}

static PyGetSetDef ThreadHandle_getsetlist[] = {
Expand Down Expand Up @@ -1424,12 +1518,12 @@ thread_PyThread_start_joinable_thread(PyObject *module, PyObject *func)
}
if (do_start_new_thread(state, func, args, /*kwargs=*/ NULL, /*joinable=*/ 1,
&hobj->ident, &hobj->handle)) {
invalidate_thread_handle(hobj);
Py_DECREF(args);
Py_DECREF(hobj);
return NULL;
}
Py_DECREF(args);
hobj->joinable = 1;
return (PyObject*) hobj;
}

Expand Down

0 comments on commit 1eeef32

Please sign in to comment.