From 1eeef328847b5e64fbb2035bef70971857fafc56 Mon Sep 17 00:00:00 2001 From: Matt Page Date: Wed, 7 Feb 2024 19:36:26 -0800 Subject: [PATCH] Make `_thread.ThreadHandle` thread-safe in free-threaded builds We protect the mutable state of `ThreadHandle` using a `_PyOnceFlag`. Concurrent operations (i.e. `join` or `detach`) on `ThreadHandle` block until it is their turn to execute or an earlier operation succeeds. Once an operation has been applied successfully all future operations complete immediately. The `join()` method is now idempotent. It may be called multiple times but the underlying OS thread will only be joined once. After `join()` succeeds, any future calls to `join()` will succeed immediately. The `detach()` method is also idempotent. It may be called multiple times but the underlying OS thread will only be detached once. After `detach()` succeeds, any future calls to `detach()` will succeed immediately. If the handle is being joined, `detach()` blocks until the join completes. --- Lib/test/test_thread.py | 9 +-- Modules/_threadmodule.c | 168 +++++++++++++++++++++++++++++++--------- 2 files changed, 135 insertions(+), 42 deletions(-) diff --git a/Lib/test/test_thread.py b/Lib/test/test_thread.py index 931cb4b797e0b21..4edc98d882599f6 100644 --- a/Lib/test/test_thread.py +++ b/Lib/test/test_thread.py @@ -189,8 +189,7 @@ def task(): with threading_helper.wait_threads_exit(): handle = thread.start_joinable_thread(task) handle.join() - with self.assertRaisesRegex(ValueError, "not joinable"): - handle.join() + handle.join() def test_joinable_not_joined(self): handle_destroyed = thread.allocate_lock() @@ -255,7 +254,7 @@ def task(): handles.append(handle) start_joinable_thread_returned.release() thread_detached.acquire() - with self.assertRaisesRegex(ValueError, "not joinable"): + with self.assertRaisesRegex(ValueError, "detached and thus cannot be joined"): handle.join() assert len(errors) == 0 @@ -272,7 +271,7 @@ def task(): # detach() returns even though the thread is blocked on lock handle.detach() # join() then cannot be called anymore - with self.assertRaisesRegex(ValueError, "not joinable"): + with self.assertRaisesRegex(ValueError, "detached and thus cannot be joined"): handle.join() lock.release() @@ -283,7 +282,7 @@ def task(): with threading_helper.wait_threads_exit(): handle = thread.start_joinable_thread(task) handle.join() - with self.assertRaisesRegex(ValueError, "not joinable"): + with self.assertRaisesRegex(ValueError, "joined and thus cannot be detached"): handle.detach() diff --git a/Modules/_threadmodule.c b/Modules/_threadmodule.c index df02b023012fbde..0191686138181ff 100644 --- a/Modules/_threadmodule.c +++ b/Modules/_threadmodule.c @@ -1,9 +1,9 @@ - /* Thread module */ /* Interface to Sjoerd's portable C thread library */ #include "Python.h" #include "pycore_interp.h" // _PyInterpreterState.threads.count +#include "pycore_lock.h" #include "pycore_moduleobject.h" // _PyModule_GetState() #include "pycore_modsupport.h" // _PyArg_NoKeywords() #include "pycore_pylifecycle.h" @@ -42,14 +42,49 @@ get_thread_state(PyObject *module) // _ThreadHandle type +typedef enum { + THREAD_HANDLE_INVALID, + THREAD_HANDLE_JOINED, + THREAD_HANDLE_DETACHED, +} ThreadHandleState; + +// A handle to join or detach an OS thread. +// +// Joining or detaching the handle is idempotent; the underlying OS thread is +// joined or detached only once. Concurrent operations block until it is their +// turn to execute or an operation completes successfully. Once an operation +// has completed successfully all future operations complete immediately. typedef struct { PyObject_HEAD struct llist_node node; // linked list node (see _pythread_runtime_state) + + // The `ident` and `handle` fields are immutable once the object is visible + // to threads other than its creator, thus they do not need to be accessed + // atomically. PyThread_ident_t ident; PyThread_handle_t handle; - char joinable; + + // State is set once by the first successful `join` or `detach` operation + // (or if the handle is invalidated). + ThreadHandleState state; + _PyOnceFlag once; } ThreadHandleObject; +// An operation on a ThreadHandle that sets the state and return 0 on success +// or returns -1 on error. +typedef int (*ThreadHandleOp)(ThreadHandleObject *); + +// Maybe execute op on the handle. Returns -1 if op was called and returned -1. +static int +try_thread_handle_op(ThreadHandleOp op, ThreadHandleObject *handle) +{ + if (_PyOnceFlag_CallOnce(&handle->once, (_Py_once_fn_t *)op, handle) == + -1) { + return -1; + } + return handle->state; +} + static ThreadHandleObject* new_thread_handle(thread_module_state* state) { @@ -59,7 +94,7 @@ new_thread_handle(thread_module_state* state) } self->ident = 0; self->handle = 0; - self->joinable = 0; + self->once = (_PyOnceFlag){0}; HEAD_LOCK(&_PyRuntime); llist_insert_tail(&_PyRuntime.threads.handles, &self->node); @@ -68,6 +103,17 @@ new_thread_handle(thread_module_state* state) return self; } +static int +detach_thread(ThreadHandleObject *handle) +{ + // This is typically short so no need to release the GIL + if (PyThread_detach_thread(handle->handle)) { + return -1; + } + handle->state = THREAD_HANDLE_DETACHED; + return 0; +} + static void ThreadHandle_dealloc(ThreadHandleObject *self) { @@ -80,17 +126,30 @@ ThreadHandle_dealloc(ThreadHandleObject *self) } HEAD_UNLOCK(&_PyRuntime); - if (self->joinable) { - int ret = PyThread_detach_thread(self->handle); - if (ret) { - PyErr_SetString(ThreadError, "Failed detaching thread"); - PyErr_WriteUnraisable(tp); - } + if (try_thread_handle_op(detach_thread, self) == -1) { + PyErr_SetString(ThreadError, "Failed detaching thread"); + PyErr_WriteUnraisable(tp); } PyObject_Free(self); Py_DECREF(tp); } +static int +do_invalidate_thread_handle(ThreadHandleObject *handle) +{ + handle->state = THREAD_HANDLE_INVALID; + return 0; +} + +static void +invalidate_thread_handle(ThreadHandleObject *handle) +{ + if (try_thread_handle_op(do_invalidate_thread_handle, handle) == -1) { + Py_FatalError("failed invalidating thread handle"); + Py_UNREACHABLE(); + } +} + void _PyThread_AfterFork(struct _pythread_runtime_state *state) { @@ -108,7 +167,7 @@ _PyThread_AfterFork(struct _pythread_runtime_state *state) } // Disallow calls to detach() and join() as they could crash. - hobj->joinable = 0; + invalidate_thread_handle(hobj); llist_remove(node); } } @@ -126,49 +185,84 @@ ThreadHandle_get_ident(ThreadHandleObject *self, void *ignored) return PyLong_FromUnsignedLongLong(self->ident); } +static PyObject * +invalid_handle_error(void) +{ + PyErr_SetString(PyExc_ValueError, + "the handle is invalid and thus cannot be detached"); + return NULL; +} static PyObject * ThreadHandle_detach(ThreadHandleObject *self, void* ignored) { - if (!self->joinable) { - PyErr_SetString(PyExc_ValueError, - "the thread is not joinable and thus cannot be detached"); - return NULL; + switch (try_thread_handle_op(detach_thread, self)) { + case -1: { + PyErr_SetString(ThreadError, "Failed detaching thread"); + return NULL; + } + case THREAD_HANDLE_INVALID: { + return invalid_handle_error(); + } + case THREAD_HANDLE_JOINED: { + PyErr_SetString( + PyExc_ValueError, + "the thread has been joined and thus cannot be detached"); + return NULL; + } + case THREAD_HANDLE_DETACHED: { + Py_RETURN_NONE; + } + default: { + Py_UNREACHABLE(); + } } - self->joinable = 0; - // This is typically short so no need to release the GIL - int ret = PyThread_detach_thread(self->handle); - if (ret) { - PyErr_SetString(ThreadError, "Failed detaching thread"); - return NULL; +} + +static int +join_thread(ThreadHandleObject *handle) +{ + int err; + Py_BEGIN_ALLOW_THREADS + err = PyThread_join_thread(handle->handle); + Py_END_ALLOW_THREADS + if (err) { + return -1; } - Py_RETURN_NONE; + handle->state = THREAD_HANDLE_JOINED; + return 0; } static PyObject * ThreadHandle_join(ThreadHandleObject *self, void* ignored) { - if (!self->joinable) { - PyErr_SetString(PyExc_ValueError, "the thread is not joinable"); - return NULL; - } if (self->ident == PyThread_get_thread_ident_ex()) { // PyThread_join_thread() would deadlock or error out. PyErr_SetString(ThreadError, "Cannot join current thread"); return NULL; } - // Before actually joining, we must first mark the thread as non-joinable, - // as joining several times simultaneously or sequentially is undefined behavior. - self->joinable = 0; - int ret; - Py_BEGIN_ALLOW_THREADS - ret = PyThread_join_thread(self->handle); - Py_END_ALLOW_THREADS - if (ret) { - PyErr_SetString(ThreadError, "Failed joining thread"); - return NULL; + + switch (try_thread_handle_op(join_thread, self)) { + case -1: { + PyErr_SetString(ThreadError, "Failed joining thread"); + return NULL; + } + case THREAD_HANDLE_INVALID: { + return invalid_handle_error(); + } + case THREAD_HANDLE_JOINED: { + Py_RETURN_NONE; + } + case THREAD_HANDLE_DETACHED: { + PyErr_SetString( + PyExc_ValueError, + "the thread is detached and thus cannot be joined"); + return NULL; + } + default: { + Py_UNREACHABLE(); + } } - Py_RETURN_NONE; } static PyGetSetDef ThreadHandle_getsetlist[] = { @@ -1424,12 +1518,12 @@ thread_PyThread_start_joinable_thread(PyObject *module, PyObject *func) } if (do_start_new_thread(state, func, args, /*kwargs=*/ NULL, /*joinable=*/ 1, &hobj->ident, &hobj->handle)) { + invalidate_thread_handle(hobj); Py_DECREF(args); Py_DECREF(hobj); return NULL; } Py_DECREF(args); - hobj->joinable = 1; return (PyObject*) hobj; }