From 906d618292b728e0d0e508fd3287180caffb3c87 Mon Sep 17 00:00:00 2001 From: Matt Page Date: Wed, 7 Feb 2024 19:36:26 -0800 Subject: [PATCH] Make `_thread.ThreadHandle` thread-safe in free-threaded builds We protect the mutable state of `ThreadHandle` using a `_PyOnceFlag`. Concurrent operations (i.e. `join` or `detach`) on `ThreadHandle` block until it is their turn to execute or an earlier operation succeeds. Once an operation has been applied successfully all future operations complete immediately. The `join()` method is now idempotent. It may be called multiple times but the underlying OS thread will only be joined once. After `join()` succeeds, any future calls to `join()` will succeed immediately. The `detach()` method is also idempotent. It may be called multiple times but the underlying OS thread will only be detached once. After `detach()` succeeds, any future calls to `detach()` will succeed immediately. If the handle is being joined, `detach()` blocks until the join completes. --- Lib/test/test_thread.py | 9 +-- Lib/threading.py | 21 ++---- Modules/_threadmodule.c | 163 +++++++++++++++++++++++++++++++--------- 3 files changed, 135 insertions(+), 58 deletions(-) diff --git a/Lib/test/test_thread.py b/Lib/test/test_thread.py index 931cb4b797e0b21..4edc98d882599f6 100644 --- a/Lib/test/test_thread.py +++ b/Lib/test/test_thread.py @@ -189,8 +189,7 @@ def task(): with threading_helper.wait_threads_exit(): handle = thread.start_joinable_thread(task) handle.join() - with self.assertRaisesRegex(ValueError, "not joinable"): - handle.join() + handle.join() def test_joinable_not_joined(self): handle_destroyed = thread.allocate_lock() @@ -255,7 +254,7 @@ def task(): handles.append(handle) start_joinable_thread_returned.release() thread_detached.acquire() - with self.assertRaisesRegex(ValueError, "not joinable"): + with self.assertRaisesRegex(ValueError, "detached and thus cannot be joined"): handle.join() assert len(errors) == 0 @@ -272,7 +271,7 @@ def task(): # detach() returns even though the thread is blocked on lock handle.detach() # join() then cannot be called anymore - with self.assertRaisesRegex(ValueError, "not joinable"): + with self.assertRaisesRegex(ValueError, "detached and thus cannot be joined"): handle.join() lock.release() @@ -283,7 +282,7 @@ def task(): with threading_helper.wait_threads_exit(): handle = thread.start_joinable_thread(task) handle.join() - with self.assertRaisesRegex(ValueError, "not joinable"): + with self.assertRaisesRegex(ValueError, "joined and thus cannot be detached"): handle.detach() diff --git a/Lib/threading.py b/Lib/threading.py index b6ff00acadd58fe..364c45d840cf2b0 100644 --- a/Lib/threading.py +++ b/Lib/threading.py @@ -956,14 +956,11 @@ def _after_fork(self, new_ident=None): if self._tstate_lock is not None: self._tstate_lock._at_fork_reinit() self._tstate_lock.acquire() - if self._join_lock is not None: - self._join_lock._at_fork_reinit() else: # This thread isn't alive after fork: it doesn't have a tstate # anymore. self._is_stopped = True self._tstate_lock = None - self._join_lock = None self._handle = None def __repr__(self): @@ -996,8 +993,6 @@ def start(self): if self._started.is_set(): raise RuntimeError("threads can only be started once") - self._join_lock = _allocate_lock() - with _active_limbo_lock: _limbo[self] = self try: @@ -1167,17 +1162,7 @@ def join(self, timeout=None): self._join_os_thread() def _join_os_thread(self): - join_lock = self._join_lock - if join_lock is None: - return - with join_lock: - # Calling join() multiple times would raise an exception - # in one of the callers. - if self._handle is not None: - self._handle.join() - self._handle = None - # No need to keep this around - self._join_lock = None + self._handle.join() def _wait_for_tstate_lock(self, block=True, timeout=-1): # Issue #18808: wait for the thread state to be gone. @@ -1478,6 +1463,10 @@ def __init__(self): with _active_limbo_lock: _active[self._ident] = self + def _join_os_thread(self): + # No ThreadHandle for main thread + pass + # Helper thread-local instance to detect when a _DummyThread # is collected. Not a part of the public API. diff --git a/Modules/_threadmodule.c b/Modules/_threadmodule.c index df02b023012fbde..a2333427685d170 100644 --- a/Modules/_threadmodule.c +++ b/Modules/_threadmodule.c @@ -1,9 +1,9 @@ - /* Thread module */ /* Interface to Sjoerd's portable C thread library */ #include "Python.h" #include "pycore_interp.h" // _PyInterpreterState.threads.count +#include "pycore_lock.h" #include "pycore_moduleobject.h" // _PyModule_GetState() #include "pycore_modsupport.h" // _PyArg_NoKeywords() #include "pycore_pylifecycle.h" @@ -42,12 +42,32 @@ get_thread_state(PyObject *module) // _ThreadHandle type +typedef enum { + THREAD_HANDLE_INVALID, + THREAD_HANDLE_JOINED, + THREAD_HANDLE_DETACHED, +} ThreadHandleState; + +// A handle to join or detach an OS thread. +// +// Joining or detaching the handle is idempotent; the underlying OS thread is +// joined or detached only once. Concurrent operations block until it is their +// turn to execute or an operation completes successfully. Once an operation +// has completed successfully all future operations complete immediately. typedef struct { PyObject_HEAD struct llist_node node; // linked list node (see _pythread_runtime_state) + + // The `ident` and `handle` fields are immutable once the object is visible + // to threads other than its creator, thus they do not need to be accessed + // atomically. PyThread_ident_t ident; PyThread_handle_t handle; - char joinable; + + // State is set once by the first successful `join` or `detach` operation + // (or if the handle is invalidated). + ThreadHandleState state; + _PyOnceFlag once; } ThreadHandleObject; static ThreadHandleObject* @@ -59,7 +79,8 @@ new_thread_handle(thread_module_state* state) } self->ident = 0; self->handle = 0; - self->joinable = 0; + self->state = THREAD_HANDLE_INVALID; + self->once = (_PyOnceFlag){0}; HEAD_LOCK(&_PyRuntime); llist_insert_tail(&_PyRuntime.threads.handles, &self->node); @@ -68,6 +89,18 @@ new_thread_handle(thread_module_state* state) return self; } +static int +detach_thread(ThreadHandleObject *handle) +{ + // This is typically short so no need to release the GIL + if (PyThread_detach_thread(handle->handle)) { + PyErr_SetString(ThreadError, "Failed detaching thread"); + return -1; + } + handle->state = THREAD_HANDLE_DETACHED; + return 0; +} + static void ThreadHandle_dealloc(ThreadHandleObject *self) { @@ -80,17 +113,32 @@ ThreadHandle_dealloc(ThreadHandleObject *self) } HEAD_UNLOCK(&_PyRuntime); - if (self->joinable) { - int ret = PyThread_detach_thread(self->handle); - if (ret) { - PyErr_SetString(ThreadError, "Failed detaching thread"); - PyErr_WriteUnraisable(tp); - } + if (_PyOnceFlag_CallOnce(&self->once, (_Py_once_fn_t *)detach_thread, + self) == -1) { + PyErr_WriteUnraisable(tp); } PyObject_Free(self); Py_DECREF(tp); } +static int +do_invalidate_thread_handle(ThreadHandleObject *handle) +{ + handle->state = THREAD_HANDLE_INVALID; + return 0; +} + +static void +invalidate_thread_handle(ThreadHandleObject *handle) +{ + if (_PyOnceFlag_CallOnce(&handle->once, + (_Py_once_fn_t *)do_invalidate_thread_handle, + handle) == -1) { + Py_FatalError("failed invalidating thread handle"); + Py_UNREACHABLE(); + } +} + void _PyThread_AfterFork(struct _pythread_runtime_state *state) { @@ -107,8 +155,11 @@ _PyThread_AfterFork(struct _pythread_runtime_state *state) continue; } - // Disallow calls to detach() and join() as they could crash. - hobj->joinable = 0; + // Disallow calls to detach() and join() on handles who were not + // previously joined or detached as they could crash. Calls to detach() + // or join() on handles that were successfully joined or detached are + // allowed as they do not perform any unsafe operations. + invalidate_thread_handle(hobj); llist_remove(node); } } @@ -126,49 +177,87 @@ ThreadHandle_get_ident(ThreadHandleObject *self, void *ignored) return PyLong_FromUnsignedLongLong(self->ident); } +static PyObject * +invalid_handle_error(void) +{ + PyErr_SetString(PyExc_ValueError, + "the handle is invalid and thus cannot be detached"); + return NULL; +} static PyObject * ThreadHandle_detach(ThreadHandleObject *self, void* ignored) { - if (!self->joinable) { - PyErr_SetString(PyExc_ValueError, - "the thread is not joinable and thus cannot be detached"); + if (_PyOnceFlag_CallOnce(&self->once, (_Py_once_fn_t *)detach_thread, + self) == -1) { return NULL; } - self->joinable = 0; - // This is typically short so no need to release the GIL - int ret = PyThread_detach_thread(self->handle); - if (ret) { - PyErr_SetString(ThreadError, "Failed detaching thread"); - return NULL; + + switch (self->state) { + case THREAD_HANDLE_DETACHED: { + Py_RETURN_NONE; + } + case THREAD_HANDLE_INVALID: { + return invalid_handle_error(); + } + case THREAD_HANDLE_JOINED: { + PyErr_SetString( + PyExc_ValueError, + "the thread has been joined and thus cannot be detached"); + return NULL; + } + default: { + Py_UNREACHABLE(); + } } - Py_RETURN_NONE; } -static PyObject * -ThreadHandle_join(ThreadHandleObject *self, void* ignored) +static int +join_thread(ThreadHandleObject *handle) { - if (!self->joinable) { - PyErr_SetString(PyExc_ValueError, "the thread is not joinable"); - return NULL; - } - if (self->ident == PyThread_get_thread_ident_ex()) { + if (handle->ident == PyThread_get_thread_ident_ex()) { // PyThread_join_thread() would deadlock or error out. PyErr_SetString(ThreadError, "Cannot join current thread"); - return NULL; + return -1; } - // Before actually joining, we must first mark the thread as non-joinable, - // as joining several times simultaneously or sequentially is undefined behavior. - self->joinable = 0; - int ret; + + int err; Py_BEGIN_ALLOW_THREADS - ret = PyThread_join_thread(self->handle); + err = PyThread_join_thread(handle->handle); Py_END_ALLOW_THREADS - if (ret) { + if (err) { PyErr_SetString(ThreadError, "Failed joining thread"); + return -1; + } + handle->state = THREAD_HANDLE_JOINED; + return 0; +} + +static PyObject * +ThreadHandle_join(ThreadHandleObject *self, void* ignored) +{ + if (_PyOnceFlag_CallOnce(&self->once, (_Py_once_fn_t *)join_thread, + self) == -1) { return NULL; } - Py_RETURN_NONE; + + switch (self->state) { + case THREAD_HANDLE_DETACHED: { + PyErr_SetString( + PyExc_ValueError, + "the thread is detached and thus cannot be joined"); + return NULL; + } + case THREAD_HANDLE_INVALID: { + return invalid_handle_error(); + } + case THREAD_HANDLE_JOINED: { + Py_RETURN_NONE; + } + default: { + Py_UNREACHABLE(); + } + } } static PyGetSetDef ThreadHandle_getsetlist[] = { @@ -1424,12 +1513,12 @@ thread_PyThread_start_joinable_thread(PyObject *module, PyObject *func) } if (do_start_new_thread(state, func, args, /*kwargs=*/ NULL, /*joinable=*/ 1, &hobj->ident, &hobj->handle)) { + invalidate_thread_handle(hobj); Py_DECREF(args); Py_DECREF(hobj); return NULL; } Py_DECREF(args); - hobj->joinable = 1; return (PyObject*) hobj; }