Skip to content

Commit

Permalink
Make _thread.ThreadHandle thread-safe in free-threaded builds
Browse files Browse the repository at this point in the history
We protect the mutable state of `ThreadHandle` using a `_PyOnceFlag`.
Concurrent operations (i.e., `join` or `detach`) on a `ThreadHandle` block
until it is their turn to execute or an earlier operation succeeds.
Once an operation has been applied successfully, all future operations
complete immediately.

The `join()` method is now idempotent. It may be called multiple times
but the underlying OS thread will only be joined once. After `join()`
succeeds, any future calls to `join()` will succeed immediately.

The `detach()` method is also idempotent. It may be called multiple times
but the underlying OS thread will only be detached once. After `detach()`
succeeds, any future calls to `detach()` will succeed immediately.

If the handle is being joined, `detach()` blocks until the join completes.
  • Loading branch information
mpage committed Feb 8, 2024
1 parent ef3ceab commit cf6491c
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 58 deletions.
20 changes: 15 additions & 5 deletions Lib/test/test_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,8 @@ def task():
with threading_helper.wait_threads_exit():
handle = thread.start_joinable_thread(task)
handle.join()
with self.assertRaisesRegex(ValueError, "not joinable"):
handle.join()
# Subsequent join() calls should succeed
handle.join()

def test_joinable_not_joined(self):
handle_destroyed = thread.allocate_lock()
Expand Down Expand Up @@ -255,7 +255,7 @@ def task():
handles.append(handle)
start_joinable_thread_returned.release()
thread_detached.acquire()
with self.assertRaisesRegex(ValueError, "not joinable"):
with self.assertRaisesRegex(ValueError, "detached and thus cannot be joined"):
handle.join()

assert len(errors) == 0
Expand All @@ -272,7 +272,7 @@ def task():
# detach() returns even though the thread is blocked on lock
handle.detach()
# join() then cannot be called anymore
with self.assertRaisesRegex(ValueError, "not joinable"):
with self.assertRaisesRegex(ValueError, "detached and thus cannot be joined"):
handle.join()
lock.release()

Expand All @@ -283,9 +283,19 @@ def task():
with threading_helper.wait_threads_exit():
handle = thread.start_joinable_thread(task)
handle.join()
with self.assertRaisesRegex(ValueError, "not joinable"):
with self.assertRaisesRegex(ValueError, "joined and thus cannot be detached"):
handle.detach()

def test_detach_then_detach(self):
    """detach() is idempotent: repeated calls on one handle all succeed."""
    def noop():
        return None

    with threading_helper.wait_threads_exit():
        joinable = thread.start_joinable_thread(noop)
        # Calling detach() a second time must succeed rather than raise.
        for _ in range(2):
            joinable.detach()


class Barrier:
def __init__(self, num_threads):
Expand Down
21 changes: 5 additions & 16 deletions Lib/threading.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,14 +956,11 @@ def _after_fork(self, new_ident=None):
if self._tstate_lock is not None:
self._tstate_lock._at_fork_reinit()
self._tstate_lock.acquire()
if self._join_lock is not None:
self._join_lock._at_fork_reinit()
else:
# This thread isn't alive after fork: it doesn't have a tstate
# anymore.
self._is_stopped = True
self._tstate_lock = None
self._join_lock = None
self._handle = None

def __repr__(self):
Expand Down Expand Up @@ -996,8 +993,6 @@ def start(self):
if self._started.is_set():
raise RuntimeError("threads can only be started once")

self._join_lock = _allocate_lock()

with _active_limbo_lock:
_limbo[self] = self
try:
Expand Down Expand Up @@ -1167,17 +1162,7 @@ def join(self, timeout=None):
self._join_os_thread()

def _join_os_thread(self):
join_lock = self._join_lock
if join_lock is None:
return
with join_lock:
# Calling join() multiple times would raise an exception
# in one of the callers.
if self._handle is not None:
self._handle.join()
self._handle = None
# No need to keep this around
self._join_lock = None
self._handle.join()

def _wait_for_tstate_lock(self, block=True, timeout=-1):
# Issue #18808: wait for the thread state to be gone.
Expand Down Expand Up @@ -1478,6 +1463,10 @@ def __init__(self):
with _active_limbo_lock:
_active[self._ident] = self

def _join_os_thread(self):
# No ThreadHandle for main thread
pass


# Helper thread-local instance to detect when a _DummyThread
# is collected. Not a part of the public API.
Expand Down
162 changes: 125 additions & 37 deletions Modules/_threadmodule.c
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@

/* Thread module */
/* Interface to Sjoerd's portable C thread library */

#include "Python.h"
#include "pycore_interp.h" // _PyInterpreterState.threads.count
#include "pycore_lock.h"
#include "pycore_moduleobject.h" // _PyModule_GetState()
#include "pycore_modsupport.h" // _PyArg_NoKeywords()
#include "pycore_pylifecycle.h"
Expand Down Expand Up @@ -42,12 +42,32 @@ get_thread_state(PyObject *module)

// _ThreadHandle type

// Outcome recorded on a thread handle by the operation that settled it.
typedef enum {
    // The handle was invalidated; it must not be joined or detached.
    THREAD_HANDLE_INVALID = 0,
    // The OS thread was joined successfully.
    THREAD_HANDLE_JOINED = 1,
    // The OS thread was detached successfully.
    THREAD_HANDLE_DETACHED = 2,
} ThreadHandleState;

// A handle to join or detach an OS thread.
//
// Joining or detaching the handle is idempotent; the underlying OS thread is
// joined or detached only once. Concurrent operations block until it is their
// turn to execute or an operation completes successfully. Once an operation
// has completed successfully, all future operations complete immediately.
typedef struct {
PyObject_HEAD
struct llist_node node; // linked list node (see _pythread_runtime_state)

// The `ident` and `handle` fields are immutable once the object is visible
// to threads other than its creator, thus they do not need to be accessed
// atomically.
PyThread_ident_t ident;
PyThread_handle_t handle;
char joinable;

// State is set once by the first successful `join` or `detach` operation
// (or if the handle is invalidated).
ThreadHandleState state;
_PyOnceFlag once;
} ThreadHandleObject;

static ThreadHandleObject*
Expand All @@ -59,7 +79,7 @@ new_thread_handle(thread_module_state* state)
}
self->ident = 0;
self->handle = 0;
self->joinable = 0;
self->once = (_PyOnceFlag){0};

HEAD_LOCK(&_PyRuntime);
llist_insert_tail(&_PyRuntime.threads.handles, &self->node);
Expand All @@ -68,6 +88,18 @@ new_thread_handle(thread_module_state* state)
return self;
}

static int
detach_thread(ThreadHandleObject *handle)
{
// This is typically short so no need to release the GIL
if (PyThread_detach_thread(handle->handle)) {
PyErr_SetString(ThreadError, "Failed detaching thread");
return -1;
}
handle->state = THREAD_HANDLE_DETACHED;
return 0;
}

static void
ThreadHandle_dealloc(ThreadHandleObject *self)
{
Expand All @@ -80,17 +112,32 @@ ThreadHandle_dealloc(ThreadHandleObject *self)
}
HEAD_UNLOCK(&_PyRuntime);

if (self->joinable) {
int ret = PyThread_detach_thread(self->handle);
if (ret) {
PyErr_SetString(ThreadError, "Failed detaching thread");
PyErr_WriteUnraisable(tp);
}
if (_PyOnceFlag_CallOnce(&self->once, (_Py_once_fn_t *)detach_thread,
self) == -1) {
PyErr_WriteUnraisable(tp);
}
PyObject_Free(self);
Py_DECREF(tp);
}

static int
do_invalidate_thread_handle(ThreadHandleObject *handle)
{
handle->state = THREAD_HANDLE_INVALID;
return 0;
}

// Run the invalidation callback on `handle` through its once flag.
//
// The callback itself always reports success, so a -1 result here means
// something is badly wrong and the process is aborted.
static void
invalidate_thread_handle(ThreadHandleObject *handle)
{
    int rc = _PyOnceFlag_CallOnce(
        &handle->once,
        (_Py_once_fn_t *)do_invalidate_thread_handle,
        handle);
    if (rc != -1) {
        return;
    }
    Py_FatalError("failed invalidating thread handle");
    Py_UNREACHABLE();
}

void
_PyThread_AfterFork(struct _pythread_runtime_state *state)
{
Expand All @@ -107,8 +154,11 @@ _PyThread_AfterFork(struct _pythread_runtime_state *state)
continue;
}

// Disallow calls to detach() and join() as they could crash.
hobj->joinable = 0;
// Disallow calls to detach() and join() on handles that were not
// previously joined or detached, as they could crash. Calls to detach()
// or join() on handles that were successfully joined or detached are
// allowed, as they do not perform any unsafe operations.
invalidate_thread_handle(hobj);
llist_remove(node);
}
}
Expand All @@ -126,49 +176,87 @@ ThreadHandle_get_ident(ThreadHandleObject *self, void *ignored)
return PyLong_FromUnsignedLongLong(self->ident);
}

// Raise ValueError for an operation attempted on an invalidated handle.
//
// This helper is returned from both the detach and the join code paths, so
// the message names both operations rather than claiming the caller tried to
// detach. The original wording is kept as a prefix so that existing
// substring or regex matches on the message continue to succeed.
//
// Always returns NULL, so callers can write `return invalid_handle_error();`.
static PyObject *
invalid_handle_error(void)
{
    PyErr_SetString(
        PyExc_ValueError,
        "the handle is invalid and thus cannot be detached or joined");
    return NULL;
}

static PyObject *
ThreadHandle_detach(ThreadHandleObject *self, void* ignored)
{
if (!self->joinable) {
PyErr_SetString(PyExc_ValueError,
"the thread is not joinable and thus cannot be detached");
if (_PyOnceFlag_CallOnce(&self->once, (_Py_once_fn_t *)detach_thread,
self) == -1) {
return NULL;
}
self->joinable = 0;
// This is typically short so no need to release the GIL
int ret = PyThread_detach_thread(self->handle);
if (ret) {
PyErr_SetString(ThreadError, "Failed detaching thread");
return NULL;

switch (self->state) {
case THREAD_HANDLE_DETACHED: {
Py_RETURN_NONE;
}
case THREAD_HANDLE_INVALID: {
return invalid_handle_error();
}
case THREAD_HANDLE_JOINED: {
PyErr_SetString(
PyExc_ValueError,
"the thread has been joined and thus cannot be detached");
return NULL;
}
default: {
Py_UNREACHABLE();
}
}
Py_RETURN_NONE;
}

static PyObject *
ThreadHandle_join(ThreadHandleObject *self, void* ignored)
static int
join_thread(ThreadHandleObject *handle)
{
if (!self->joinable) {
PyErr_SetString(PyExc_ValueError, "the thread is not joinable");
return NULL;
}
if (self->ident == PyThread_get_thread_ident_ex()) {
if (handle->ident == PyThread_get_thread_ident_ex()) {
// PyThread_join_thread() would deadlock or error out.
PyErr_SetString(ThreadError, "Cannot join current thread");
return NULL;
return -1;
}
// Before actually joining, we must first mark the thread as non-joinable,
// as joining several times simultaneously or sequentially is undefined behavior.
self->joinable = 0;
int ret;

int err;
Py_BEGIN_ALLOW_THREADS
ret = PyThread_join_thread(self->handle);
err = PyThread_join_thread(handle->handle);
Py_END_ALLOW_THREADS
if (ret) {
if (err) {
PyErr_SetString(ThreadError, "Failed joining thread");
return -1;
}
handle->state = THREAD_HANDLE_JOINED;
return 0;
}

static PyObject *
ThreadHandle_join(ThreadHandleObject *self, void* ignored)
{
if (_PyOnceFlag_CallOnce(&self->once, (_Py_once_fn_t *)join_thread,
self) == -1) {
return NULL;
}
Py_RETURN_NONE;

switch (self->state) {
case THREAD_HANDLE_DETACHED: {
PyErr_SetString(
PyExc_ValueError,
"the thread is detached and thus cannot be joined");
return NULL;
}
case THREAD_HANDLE_INVALID: {
return invalid_handle_error();
}
case THREAD_HANDLE_JOINED: {
Py_RETURN_NONE;
}
default: {
Py_UNREACHABLE();
}
}
}

static PyGetSetDef ThreadHandle_getsetlist[] = {
Expand Down Expand Up @@ -1424,12 +1512,12 @@ thread_PyThread_start_joinable_thread(PyObject *module, PyObject *func)
}
if (do_start_new_thread(state, func, args, /*kwargs=*/ NULL, /*joinable=*/ 1,
&hobj->ident, &hobj->handle)) {
invalidate_thread_handle(hobj);
Py_DECREF(args);
Py_DECREF(hobj);
return NULL;
}
Py_DECREF(args);
hobj->joinable = 1;
return (PyObject*) hobj;
}

Expand Down

0 comments on commit cf6491c

Please sign in to comment.