From 92abb0124037e5bc938fa870461a26f64c56095b Mon Sep 17 00:00:00 2001 From: Dino Viehland Date: Tue, 6 Feb 2024 14:03:43 -0800 Subject: [PATCH] gh-112075: Add critical sections for most dict APIs (#114508) Starts adding thread safety to dict objects. Use @critical_section for APIs which are exposed via argument clinic and don't directly correlate with a public C API which needs to acquire the lock Use a _lock_held suffix for keeping changes to complicated functions simple and just wrapping them with a critical section Acquire and release the lock in an existing function where it won't be overly disruptive to the existing logic --- Include/internal/pycore_critical_section.h | 46 ++ Modules/_sre/sre.c | 27 +- Objects/clinic/dictobject.c.h | 30 +- Objects/dictobject.c | 866 +++++++++++++++------ Objects/odictobject.c | 19 +- Objects/setobject.c | 78 +- 6 files changed, 782 insertions(+), 284 deletions(-) diff --git a/Include/internal/pycore_critical_section.h b/Include/internal/pycore_critical_section.h index bf2bbfffc38bd0..38ed8cd69804ba 100644 --- a/Include/internal/pycore_critical_section.h +++ b/Include/internal/pycore_critical_section.h @@ -104,12 +104,37 @@ extern "C" { # define Py_END_CRITICAL_SECTION2() \ _PyCriticalSection2_End(&_cs2); \ } + +// Asserts that the mutex is locked. The mutex must be held by the +// top-most critical section otherwise there's the possibility +// that the mutex would be swalled out in some code paths. +#define _Py_CRITICAL_SECTION_ASSERT_MUTEX_LOCKED(mutex) \ + _PyCriticalSection_AssertHeld(mutex) + +// Asserts that the mutex for the given object is locked. The mutex must +// be held by the top-most critical section otherwise there's the +// possibility that the mutex would be swalled out in some code paths. +#ifdef Py_DEBUG + +#define _Py_CRITICAL_SECTION_ASSERT_OBJECT_LOCKED(op) \ + if (Py_REFCNT(op) != 1) { \ + _Py_CRITICAL_SECTION_ASSERT_MUTEX_LOCKED(&_PyObject_CAST(op)->ob_mutex); \ + } + +#else /* Py_DEBUG */ + +#define _Py_CRITICAL_SECTION_ASSERT_OBJECT_LOCKED(op) + +#endif /* Py_DEBUG */ + #else /* !Py_GIL_DISABLED */ // The critical section APIs are no-ops with the GIL. # define Py_BEGIN_CRITICAL_SECTION(op) # define Py_END_CRITICAL_SECTION() # define Py_BEGIN_CRITICAL_SECTION2(a, b) # define Py_END_CRITICAL_SECTION2() +# define _Py_CRITICAL_SECTION_ASSERT_MUTEX_LOCKED(mutex) +# define _Py_CRITICAL_SECTION_ASSERT_OBJECT_LOCKED(op) #endif /* !Py_GIL_DISABLED */ typedef struct { @@ -236,6 +261,27 @@ _PyCriticalSection2_End(_PyCriticalSection2 *c) PyAPI_FUNC(void) _PyCriticalSection_SuspendAll(PyThreadState *tstate); +#ifdef Py_GIL_DISABLED + +static inline void +_PyCriticalSection_AssertHeld(PyMutex *mutex) { +#ifdef Py_DEBUG + PyThreadState *tstate = _PyThreadState_GET(); + uintptr_t prev = tstate->critical_section; + if (prev & _Py_CRITICAL_SECTION_TWO_MUTEXES) { + _PyCriticalSection2 *cs = (_PyCriticalSection2 *)(prev & ~_Py_CRITICAL_SECTION_MASK); + assert(cs != NULL && (cs->base.mutex == mutex || cs->mutex2 == mutex)); + } + else { + _PyCriticalSection *cs = (_PyCriticalSection *)(tstate->critical_section & ~_Py_CRITICAL_SECTION_MASK); + assert(cs != NULL && cs->mutex == mutex); + } + +#endif +} + +#endif + #ifdef __cplusplus } #endif diff --git a/Modules/_sre/sre.c b/Modules/_sre/sre.c index d451974b9cf81e..00fbd9674b8cdd 100644 --- a/Modules/_sre/sre.c +++ b/Modules/_sre/sre.c @@ -39,13 +39,14 @@ static const char copyright[] = " SRE 2.2.2 Copyright (c) 1997-2002 by Secret Labs AB "; #include "Python.h" -#include "pycore_dict.h" // _PyDict_Next() -#include "pycore_long.h" // _PyLong_GetZero() -#include "pycore_moduleobject.h" // _PyModule_GetState() +#include "pycore_critical_section.h" // Py_BEGIN_CRITICAL_SECTION +#include "pycore_dict.h" // _PyDict_Next() +#include "pycore_long.h" // _PyLong_GetZero() +#include "pycore_moduleobject.h" // _PyModule_GetState() -#include "sre.h" // SRE_CODE +#include "sre.h" // SRE_CODE -#include // tolower(), toupper(), isalnum() +#include // tolower(), toupper(), isalnum() #define SRE_CODE_BITS (8 * sizeof(SRE_CODE)) @@ -2349,26 +2350,28 @@ _sre_SRE_Match_groupdict_impl(MatchObject *self, PyObject *default_value) if (!result || !self->pattern->groupindex) return result; + Py_BEGIN_CRITICAL_SECTION(self->pattern->groupindex); while (_PyDict_Next(self->pattern->groupindex, &pos, &key, &value, &hash)) { int status; Py_INCREF(key); value = match_getslice(self, key, default_value); if (!value) { Py_DECREF(key); - goto failed; + Py_CLEAR(result); + goto exit; } status = _PyDict_SetItem_KnownHash(result, key, value, hash); Py_DECREF(value); Py_DECREF(key); - if (status < 0) - goto failed; + if (status < 0) { + Py_CLEAR(result); + goto exit; + } } +exit: + Py_END_CRITICAL_SECTION(); return result; - -failed: - Py_DECREF(result); - return NULL; } /*[clinic input] diff --git a/Objects/clinic/dictobject.c.h b/Objects/clinic/dictobject.c.h index 8f532f454156de..daaef211b1db49 100644 --- a/Objects/clinic/dictobject.c.h +++ b/Objects/clinic/dictobject.c.h @@ -2,6 +2,7 @@ preserve [clinic start generated code]*/ +#include "pycore_critical_section.h"// Py_BEGIN_CRITICAL_SECTION() #include "pycore_modsupport.h" // _PyArg_CheckPositional() PyDoc_STRVAR(dict_fromkeys__doc__, @@ -65,6 +66,21 @@ PyDoc_STRVAR(dict___contains____doc__, #define DICT___CONTAINS___METHODDEF \ {"__contains__", (PyCFunction)dict___contains__, METH_O|METH_COEXIST, dict___contains____doc__}, +static PyObject * +dict___contains___impl(PyDictObject *self, PyObject *key); + +static PyObject * +dict___contains__(PyDictObject *self, PyObject *key) +{ + PyObject *return_value = NULL; + + Py_BEGIN_CRITICAL_SECTION(self); + return_value = dict___contains___impl(self, key); + Py_END_CRITICAL_SECTION(); + + return return_value; +} + PyDoc_STRVAR(dict_get__doc__, "get($self, key, default=None, /)\n" "--\n" @@ -93,7 +109,9 @@ dict_get(PyDictObject *self, PyObject *const *args, Py_ssize_t nargs) } default_value = args[1]; skip_optional: + Py_BEGIN_CRITICAL_SECTION(self); return_value = dict_get_impl(self, key, default_value); + Py_END_CRITICAL_SECTION(); exit: return return_value; @@ -130,7 +148,9 @@ dict_setdefault(PyDictObject *self, PyObject *const *args, Py_ssize_t nargs) } default_value = args[1]; skip_optional: + Py_BEGIN_CRITICAL_SECTION(self); return_value = dict_setdefault_impl(self, key, default_value); + Py_END_CRITICAL_SECTION(); exit: return return_value; @@ -209,7 +229,13 @@ dict_popitem_impl(PyDictObject *self); static PyObject * dict_popitem(PyDictObject *self, PyObject *Py_UNUSED(ignored)) { - return dict_popitem_impl(self); + PyObject *return_value = NULL; + + Py_BEGIN_CRITICAL_SECTION(self); + return_value = dict_popitem_impl(self); + Py_END_CRITICAL_SECTION(); + + return return_value; } PyDoc_STRVAR(dict___sizeof____doc__, @@ -301,4 +327,4 @@ dict_values(PyDictObject *self, PyObject *Py_UNUSED(ignored)) { return dict_values_impl(self); } -/*[clinic end generated code: output=f3ac47dfbf341b23 input=a9049054013a1b77]*/ +/*[clinic end generated code: output=c8fda06bac5b05f3 input=a9049054013a1b77]*/ diff --git a/Objects/dictobject.c b/Objects/dictobject.c index 11b388d9f4adb0..2df95e977a180f 100644 --- a/Objects/dictobject.c +++ b/Objects/dictobject.c @@ -113,18 +113,19 @@ As a consequence of this, split keys have a maximum size of 16. #define PyDict_MINSIZE 8 #include "Python.h" -#include "pycore_bitutils.h" // _Py_bit_length -#include "pycore_call.h" // _PyObject_CallNoArgs() -#include "pycore_ceval.h" // _PyEval_GetBuiltin() -#include "pycore_code.h" // stats -#include "pycore_dict.h" // export _PyDict_SizeOf() -#include "pycore_freelist.h" // _PyFreeListState_GET() -#include "pycore_gc.h" // _PyObject_GC_IS_TRACKED() -#include "pycore_object.h" // _PyObject_GC_TRACK(), _PyDebugAllocatorStats() -#include "pycore_pyerrors.h" // _PyErr_GetRaisedException() -#include "pycore_pystate.h" // _PyThreadState_GET() -#include "pycore_setobject.h" // _PySet_NextEntry() -#include "stringlib/eq.h" // unicode_eq() +#include "pycore_bitutils.h" // _Py_bit_length +#include "pycore_call.h" // _PyObject_CallNoArgs() +#include "pycore_ceval.h" // _PyEval_GetBuiltin() +#include "pycore_code.h" // stats +#include "pycore_critical_section.h" // Py_BEGIN_CRITICAL_SECTION, Py_END_CRITICAL_SECTION +#include "pycore_dict.h" // export _PyDict_SizeOf() +#include "pycore_freelist.h" // _PyFreeListState_GET() +#include "pycore_gc.h" // _PyObject_GC_IS_TRACKED() +#include "pycore_object.h" // _PyObject_GC_TRACK(), _PyDebugAllocatorStats() +#include "pycore_pyerrors.h" // _PyErr_GetRaisedException() +#include "pycore_pystate.h" // _PyThreadState_GET() +#include "pycore_setobject.h" // _PySet_NextEntry() +#include "stringlib/eq.h" // unicode_eq() #include @@ -141,6 +142,21 @@ To avoid slowing down lookups on a near-full table, we resize the table when it's USABLE_FRACTION (currently two-thirds) full. */ +#ifdef Py_GIL_DISABLED + +static inline void +ASSERT_DICT_LOCKED(PyObject *op) +{ + _Py_CRITICAL_SECTION_ASSERT_OBJECT_LOCKED(op); +} +#define ASSERT_DICT_LOCKED(op) ASSERT_DICT_LOCKED(_Py_CAST(PyObject*, op)) + +#else + +#define ASSERT_DICT_LOCKED(op) + +#endif + #define PERTURB_SHIFT 5 /* @@ -240,6 +256,16 @@ static int dictresize(PyInterpreterState *interp, PyDictObject *mp, static PyObject* dict_iter(PyObject *dict); +static int +contains_lock_held(PyDictObject *mp, PyObject *key); +static int +contains_known_hash_lock_held(PyDictObject *mp, PyObject *key, Py_ssize_t hash); +static int +setitem_lock_held(PyDictObject *mp, PyObject *key, PyObject *value); +static int +dict_setdefault_ref_lock_held(PyObject *d, PyObject *key, PyObject *default_value, + PyObject **result, int incref_result); + #include "clinic/dictobject.c.h" @@ -789,6 +815,8 @@ clone_combined_dict_keys(PyDictObject *orig) assert(orig->ma_keys != Py_EMPTY_KEYS); assert(orig->ma_keys->dk_refcnt == 1); + ASSERT_DICT_LOCKED(orig); + size_t keys_size = _PyDict_KeysSize(orig->ma_keys); PyDictKeysObject *keys = PyMem_Malloc(keys_size); if (keys == NULL) { @@ -1230,6 +1258,8 @@ insertdict(PyInterpreterState *interp, PyDictObject *mp, { PyObject *old_value; + ASSERT_DICT_LOCKED(mp); + if (DK_IS_UNICODE(mp->ma_keys) && !PyUnicode_CheckExact(key)) { if (insertion_resize(interp, mp, 0) < 0) goto Fail; @@ -1326,6 +1356,7 @@ insert_to_emptydict(PyInterpreterState *interp, PyDictObject *mp, PyObject *key, Py_hash_t hash, PyObject *value) { assert(mp->ma_keys == Py_EMPTY_KEYS); + ASSERT_DICT_LOCKED(mp); uint64_t new_version = _PyDict_NotifyEvent( interp, PyDict_EVENT_ADDED, mp, key, value); @@ -1419,6 +1450,8 @@ dictresize(PyInterpreterState *interp, PyDictObject *mp, PyDictKeysObject *oldkeys; PyDictValues *oldvalues; + ASSERT_DICT_LOCKED(mp); + if (log2_newsize >= SIZEOF_SIZE_T*8) { PyErr_NoMemory(); return -1; @@ -1613,7 +1646,7 @@ _PyDict_FromItems(PyObject *const *keys, Py_ssize_t keys_offset, for (Py_ssize_t i = 0; i < length; i++) { PyObject *key = *ks; PyObject *value = *vs; - if (PyDict_SetItem(dict, key, value) < 0) { + if (setitem_lock_held((PyDictObject *)dict, key, value) < 0) { Py_DECREF(dict); return NULL; } @@ -1688,6 +1721,7 @@ PyDict_GetItem(PyObject *op, PyObject *key) Py_ssize_t _PyDict_LookupIndex(PyDictObject *mp, PyObject *key) { + // TODO: Thread safety PyObject *value; assert(PyDict_CheckExact((PyObject*)mp)); assert(PyUnicode_CheckExact(key)); @@ -1864,9 +1898,11 @@ _PyDict_LoadGlobal(PyDictObject *globals, PyDictObject *builtins, PyObject *key) } /* Consumes references to key and value */ -int -_PyDict_SetItem_Take2(PyDictObject *mp, PyObject *key, PyObject *value) +static int +setitem_take2_lock_held(PyDictObject *mp, PyObject *key, PyObject *value) { + ASSERT_DICT_LOCKED(mp); + assert(key); assert(value); assert(PyDict_Check(mp)); @@ -1879,7 +1915,9 @@ _PyDict_SetItem_Take2(PyDictObject *mp, PyObject *key, PyObject *value) return -1; } } + PyInterpreterState *interp = _PyInterpreterState_GET(); + if (mp->ma_keys == Py_EMPTY_KEYS) { return insert_to_emptydict(interp, mp, key, hash, value); } @@ -1887,6 +1925,16 @@ _PyDict_SetItem_Take2(PyDictObject *mp, PyObject *key, PyObject *value) return insertdict(interp, mp, key, hash, value); } +int +_PyDict_SetItem_Take2(PyDictObject *mp, PyObject *key, PyObject *value) +{ + int res; + Py_BEGIN_CRITICAL_SECTION(mp); + res = setitem_take2_lock_held(mp, key, value); + Py_END_CRITICAL_SECTION(); + return res; +} + /* CAUTION: PyDict_SetItem() must guarantee that it won't resize the * dictionary if it's merely replacing the value for an existing key. * This means that it's safe to loop over a dictionary with PyDict_Next() @@ -1906,6 +1954,16 @@ PyDict_SetItem(PyObject *op, PyObject *key, PyObject *value) Py_NewRef(key), Py_NewRef(value)); } +static int +setitem_lock_held(PyDictObject *mp, PyObject *key, PyObject *value) +{ + assert(key); + assert(value); + return setitem_take2_lock_held(mp, + Py_NewRef(key), Py_NewRef(value)); +} + + int _PyDict_SetItem_KnownHash(PyObject *op, PyObject *key, PyObject *value, Py_hash_t hash) @@ -1921,12 +1979,21 @@ _PyDict_SetItem_KnownHash(PyObject *op, PyObject *key, PyObject *value, assert(hash != -1); mp = (PyDictObject *)op; + int res; PyInterpreterState *interp = _PyInterpreterState_GET(); + + Py_BEGIN_CRITICAL_SECTION(mp); + if (mp->ma_keys == Py_EMPTY_KEYS) { - return insert_to_emptydict(interp, mp, Py_NewRef(key), hash, Py_NewRef(value)); + res = insert_to_emptydict(interp, mp, Py_NewRef(key), hash, Py_NewRef(value)); } - /* insertdict() handles any resizing that might be necessary */ - return insertdict(interp, mp, Py_NewRef(key), hash, Py_NewRef(value)); + else { + /* insertdict() handles any resizing that might be necessary */ + res = insertdict(interp, mp, Py_NewRef(key), hash, Py_NewRef(value)); + } + + Py_END_CRITICAL_SECTION(); + return res; } static void @@ -1951,6 +2018,8 @@ delitem_common(PyDictObject *mp, Py_hash_t hash, Py_ssize_t ix, { PyObject *old_key; + ASSERT_DICT_LOCKED(mp); + Py_ssize_t hashpos = lookdict_index(mp->ma_keys, hash, ix); assert(hashpos >= 0); @@ -2002,8 +2071,8 @@ PyDict_DelItem(PyObject *op, PyObject *key) return _PyDict_DelItem_KnownHash(op, key, hash); } -int -_PyDict_DelItem_KnownHash(PyObject *op, PyObject *key, Py_hash_t hash) +static int +delitem_knownhash_lock_held(PyObject *op, PyObject *key, Py_hash_t hash) { Py_ssize_t ix; PyDictObject *mp; @@ -2013,6 +2082,9 @@ _PyDict_DelItem_KnownHash(PyObject *op, PyObject *key, Py_hash_t hash) PyErr_BadInternalCall(); return -1; } + + ASSERT_DICT_LOCKED(op); + assert(key); assert(hash != -1); mp = (PyDictObject *)op; @@ -2030,13 +2102,19 @@ _PyDict_DelItem_KnownHash(PyObject *op, PyObject *key, Py_hash_t hash) return delitem_common(mp, hash, ix, old_value, new_version); } -/* This function promises that the predicate -> deletion sequence is atomic - * (i.e. protected by the GIL), assuming the predicate itself doesn't - * release the GIL. - */ int -_PyDict_DelItemIf(PyObject *op, PyObject *key, - int (*predicate)(PyObject *value)) +_PyDict_DelItem_KnownHash(PyObject *op, PyObject *key, Py_hash_t hash) +{ + int res; + Py_BEGIN_CRITICAL_SECTION(op); + res = delitem_knownhash_lock_held(op, key, hash); + Py_END_CRITICAL_SECTION(); + return res; +} + +static int +delitemif_lock_held(PyObject *op, PyObject *key, + int (*predicate)(PyObject *value)) { Py_ssize_t hashpos, ix; PyDictObject *mp; @@ -2044,6 +2122,8 @@ _PyDict_DelItemIf(PyObject *op, PyObject *key, PyObject *old_value; int res; + ASSERT_DICT_LOCKED(op); + if (!PyDict_Check(op)) { PyErr_BadInternalCall(); return -1; @@ -2077,16 +2157,32 @@ _PyDict_DelItemIf(PyObject *op, PyObject *key, return 0; } } +/* This function promises that the predicate -> deletion sequence is atomic + * (i.e. protected by the GIL or the per-dict mutex in free threaded builds), + * assuming the predicate itself doesn't release the GIL (or cause re-entrancy + * which would release the per-dict mutex) + */ +int +_PyDict_DelItemIf(PyObject *op, PyObject *key, + int (*predicate)(PyObject *value)) +{ + int res; + Py_BEGIN_CRITICAL_SECTION(op); + res = delitemif_lock_held(op, key, predicate); + Py_END_CRITICAL_SECTION(); + return res; +} - -void -PyDict_Clear(PyObject *op) +static void +clear_lock_held(PyObject *op) { PyDictObject *mp; PyDictKeysObject *oldkeys; PyDictValues *oldvalues; Py_ssize_t i, n; + ASSERT_DICT_LOCKED(op); + if (!PyDict_Check(op)) return; mp = ((PyDictObject *)op); @@ -2119,6 +2215,14 @@ PyDict_Clear(PyObject *op) ASSERT_CONSISTENT(mp); } +void +PyDict_Clear(PyObject *op) +{ + Py_BEGIN_CRITICAL_SECTION(op); + clear_lock_held(op); + Py_END_CRITICAL_SECTION(); +} + /* Internal version of PyDict_Next that returns a hash value in addition * to the key and value. * Return 1 on success, return 0 when the reached the end of the dictionary @@ -2135,6 +2239,9 @@ _PyDict_Next(PyObject *op, Py_ssize_t *ppos, PyObject **pkey, if (!PyDict_Check(op)) return 0; + + ASSERT_DICT_LOCKED(op); + mp = (PyDictObject *)op; i = *ppos; if (mp->ma_values) { @@ -2208,7 +2315,11 @@ _PyDict_Next(PyObject *op, Py_ssize_t *ppos, PyObject **pkey, int PyDict_Next(PyObject *op, Py_ssize_t *ppos, PyObject **pkey, PyObject **pvalue) { - return _PyDict_Next(op, ppos, pkey, pvalue, NULL); + int res; + Py_BEGIN_CRITICAL_SECTION(op); + res = _PyDict_Next(op, ppos, pkey, pvalue, NULL); + Py_END_CRITICAL_SECTION(); + return res; } @@ -2219,6 +2330,8 @@ _PyDict_Pop_KnownHash(PyDictObject *mp, PyObject *key, Py_hash_t hash, { assert(PyDict_Check(mp)); + ASSERT_DICT_LOCKED(mp); + if (mp->ma_used == 0) { if (result) { *result = NULL; @@ -2258,10 +2371,11 @@ _PyDict_Pop_KnownHash(PyDictObject *mp, PyObject *key, Py_hash_t hash, return 1; } - -int -PyDict_Pop(PyObject *op, PyObject *key, PyObject **result) +static int +pop_lock_held(PyObject *op, PyObject *key, PyObject **result) { + ASSERT_DICT_LOCKED(op); + if (!PyDict_Check(op)) { if (result) { *result = NULL; @@ -2291,6 +2405,17 @@ PyDict_Pop(PyObject *op, PyObject *key, PyObject **result) return _PyDict_Pop_KnownHash(dict, key, hash, result); } +int +PyDict_Pop(PyObject *op, PyObject *key, PyObject **result) +{ + int err; + Py_BEGIN_CRITICAL_SECTION(op); + err = pop_lock_held(op, key, result); + Py_END_CRITICAL_SECTION(); + + return err; +} + int PyDict_PopString(PyObject *op, const char *key, PyObject **result) @@ -2323,6 +2448,55 @@ _PyDict_Pop(PyObject *dict, PyObject *key, PyObject *default_value) return result; } +static PyDictObject * +dict_dict_fromkeys(PyInterpreterState *interp, PyDictObject *mp, + PyObject *iterable, PyObject *value) +{ + PyObject *oldvalue; + Py_ssize_t pos = 0; + PyObject *key; + Py_hash_t hash; + int unicode = DK_IS_UNICODE(((PyDictObject*)iterable)->ma_keys); + uint8_t new_size = Py_MAX( + estimate_log2_keysize(PyDict_GET_SIZE(iterable)), + DK_LOG_SIZE(mp->ma_keys)); + if (dictresize(interp, mp, new_size, unicode)) { + Py_DECREF(mp); + return NULL; + } + + while (_PyDict_Next(iterable, &pos, &key, &oldvalue, &hash)) { + if (insertdict(interp, mp, + Py_NewRef(key), hash, Py_NewRef(value))) { + Py_DECREF(mp); + return NULL; + } + } + return mp; +} + +static PyDictObject * +dict_set_fromkeys(PyInterpreterState *interp, PyDictObject *mp, + PyObject *iterable, PyObject *value) +{ + Py_ssize_t pos = 0; + PyObject *key; + Py_hash_t hash; + + if (dictresize(interp, mp, + estimate_log2_keysize(PySet_GET_SIZE(iterable)), 0)) { + Py_DECREF(mp); + return NULL; + } + + while (_PySet_NextEntry(iterable, &pos, &key, &hash)) { + if (insertdict(interp, mp, Py_NewRef(key), hash, Py_NewRef(value))) { + Py_DECREF(mp); + return NULL; + } + } + return mp; +} /* Internal version of dict.from_keys(). It is subclass-friendly. */ PyObject * @@ -2338,49 +2512,22 @@ _PyDict_FromKeys(PyObject *cls, PyObject *iterable, PyObject *value) if (d == NULL) return NULL; - if (PyDict_CheckExact(d) && ((PyDictObject *)d)->ma_used == 0) { + + if (PyDict_CheckExact(d)) { if (PyDict_CheckExact(iterable)) { PyDictObject *mp = (PyDictObject *)d; - PyObject *oldvalue; - Py_ssize_t pos = 0; - PyObject *key; - Py_hash_t hash; - - int unicode = DK_IS_UNICODE(((PyDictObject*)iterable)->ma_keys); - if (dictresize(interp, mp, - estimate_log2_keysize(PyDict_GET_SIZE(iterable)), - unicode)) { - Py_DECREF(d); - return NULL; - } - while (_PyDict_Next(iterable, &pos, &key, &oldvalue, &hash)) { - if (insertdict(interp, mp, - Py_NewRef(key), hash, Py_NewRef(value))) { - Py_DECREF(d); - return NULL; - } - } + Py_BEGIN_CRITICAL_SECTION2(d, iterable); + d = (PyObject *)dict_dict_fromkeys(interp, mp, iterable, value); + Py_END_CRITICAL_SECTION2(); return d; } - if (PyAnySet_CheckExact(iterable)) { + else if (PyAnySet_CheckExact(iterable)) { PyDictObject *mp = (PyDictObject *)d; - Py_ssize_t pos = 0; - PyObject *key; - Py_hash_t hash; - - if (dictresize(interp, mp, - estimate_log2_keysize(PySet_GET_SIZE(iterable)), 0)) { - Py_DECREF(d); - return NULL; - } - while (_PySet_NextEntry(iterable, &pos, &key, &hash)) { - if (insertdict(interp, mp, Py_NewRef(key), hash, Py_NewRef(value))) { - Py_DECREF(d); - return NULL; - } - } + Py_BEGIN_CRITICAL_SECTION2(d, iterable); + d = (PyObject *)dict_set_fromkeys(interp, mp, iterable, value); + Py_END_CRITICAL_SECTION2(); return d; } } @@ -2392,12 +2539,17 @@ _PyDict_FromKeys(PyObject *cls, PyObject *iterable, PyObject *value) } if (PyDict_CheckExact(d)) { + Py_BEGIN_CRITICAL_SECTION(d); while ((key = PyIter_Next(it)) != NULL) { - status = PyDict_SetItem(d, key, value); + status = setitem_lock_held((PyDictObject *)d, key, value); Py_DECREF(key); - if (status < 0) - goto Fail; + if (status < 0) { + assert(PyErr_Occurred()); + goto dict_iter_exit; + } } +dict_iter_exit: + Py_END_CRITICAL_SECTION(); } else { while ((key = PyIter_Next(it)) != NULL) { status = PyObject_SetItem(d, key, value); @@ -2468,7 +2620,7 @@ dict_dealloc(PyObject *self) static PyObject * -dict_repr(PyObject *self) +dict_repr_lock_held(PyObject *self) { PyDictObject *mp = (PyDictObject *)self; Py_ssize_t i; @@ -2498,7 +2650,7 @@ dict_repr(PyObject *self) Note that repr may mutate the dict. */ i = 0; first = 1; - while (PyDict_Next((PyObject *)mp, &i, &key, &value)) { + while (_PyDict_Next((PyObject *)mp, &i, &key, &value, NULL)) { PyObject *s; int res; @@ -2551,15 +2703,25 @@ dict_repr(PyObject *self) return NULL; } +static PyObject * +dict_repr(PyObject *self) +{ + PyObject *res; + Py_BEGIN_CRITICAL_SECTION(self); + res = dict_repr_lock_held(self); + Py_END_CRITICAL_SECTION(); + return res; +} + static Py_ssize_t dict_length(PyObject *self) { PyDictObject *mp = (PyDictObject *)self; - return mp->ma_used; + return _Py_atomic_load_ssize_relaxed(&mp->ma_used); } static PyObject * -dict_subscript(PyObject *self, PyObject *key) +dict_subscript_lock_held(PyObject *self, PyObject *key) { PyDictObject *mp = (PyDictObject *)self; Py_ssize_t ix; @@ -2594,6 +2756,16 @@ dict_subscript(PyObject *self, PyObject *key) return Py_NewRef(value); } +static PyObject * +dict_subscript(PyObject *self, PyObject *key) +{ + PyObject *res; + Py_BEGIN_CRITICAL_SECTION(self); + res = dict_subscript_lock_held(self, key); + Py_END_CRITICAL_SECTION(); + return res; +} + static int dict_ass_sub(PyObject *mp, PyObject *v, PyObject *w) { @@ -2609,9 +2781,11 @@ static PyMappingMethods dict_as_mapping = { dict_ass_sub, /*mp_ass_subscript*/ }; -PyObject * -PyDict_Keys(PyObject *dict) +static PyObject * +keys_lock_held(PyObject *dict) { + ASSERT_DICT_LOCKED(dict); + if (dict == NULL || !PyDict_Check(dict)) { PyErr_BadInternalCall(); return NULL; @@ -2646,8 +2820,21 @@ PyDict_Keys(PyObject *dict) } PyObject * -PyDict_Values(PyObject *dict) +PyDict_Keys(PyObject *dict) +{ + PyObject *res; + Py_BEGIN_CRITICAL_SECTION(dict); + res = keys_lock_held(dict); + Py_END_CRITICAL_SECTION(); + + return res; +} + +static PyObject * +values_lock_held(PyObject *dict) { + ASSERT_DICT_LOCKED(dict); + if (dict == NULL || !PyDict_Check(dict)) { PyErr_BadInternalCall(); return NULL; @@ -2682,8 +2869,20 @@ PyDict_Values(PyObject *dict) } PyObject * -PyDict_Items(PyObject *dict) +PyDict_Values(PyObject *dict) +{ + PyObject *res; + Py_BEGIN_CRITICAL_SECTION(dict); + res = values_lock_held(dict); + Py_END_CRITICAL_SECTION(); + return res; +} + +static PyObject * +items_lock_held(PyObject *dict) { + ASSERT_DICT_LOCKED(dict); + if (dict == NULL || !PyDict_Check(dict)) { PyErr_BadInternalCall(); return NULL; @@ -2732,6 +2931,17 @@ PyDict_Items(PyObject *dict) return v; } +PyObject * +PyDict_Items(PyObject *dict) +{ + PyObject *res; + Py_BEGIN_CRITICAL_SECTION(dict); + res = items_lock_held(dict); + Py_END_CRITICAL_SECTION(); + + return res; +} + /*[clinic input] @classmethod dict.fromkeys @@ -2810,8 +3020,8 @@ dict_update(PyObject *self, PyObject *args, PyObject *kwds) producing iterable objects of length 2. */ -int -PyDict_MergeFromSeq2(PyObject *d, PyObject *seq2, int override) +static int +merge_from_seq2_lock_held(PyObject *d, PyObject *seq2, int override) { PyObject *it; /* iter(seq2) */ Py_ssize_t i; /* index into seq2 of current element */ @@ -2863,14 +3073,14 @@ PyDict_MergeFromSeq2(PyObject *d, PyObject *seq2, int override) Py_INCREF(key); Py_INCREF(value); if (override) { - if (PyDict_SetItem(d, key, value) < 0) { + if (setitem_lock_held((PyDictObject *)d, key, value) < 0) { Py_DECREF(key); Py_DECREF(value); goto Fail; } } else { - if (PyDict_SetDefault(d, key, value) == NULL) { + if (dict_setdefault_ref_lock_held(d, key, value, NULL, 0) < 0) { Py_DECREF(key); Py_DECREF(value); goto Fail; @@ -2895,6 +3105,117 @@ PyDict_MergeFromSeq2(PyObject *d, PyObject *seq2, int override) return Py_SAFE_DOWNCAST(i, Py_ssize_t, int); } +int +PyDict_MergeFromSeq2(PyObject *d, PyObject *seq2, int override) +{ + int res; + Py_BEGIN_CRITICAL_SECTION(d); + res = merge_from_seq2_lock_held(d, seq2, override); + Py_END_CRITICAL_SECTION(); + + return res; +} + +static int +dict_dict_merge(PyInterpreterState *interp, PyDictObject *mp, PyDictObject *other, int override) +{ + if (other == mp || other->ma_used == 0) + /* a.update(a) or a.update({}); nothing to do */ + return 0; + if (mp->ma_used == 0) { + /* Since the target dict is empty, PyDict_GetItem() + * always returns NULL. Setting override to 1 + * skips the unnecessary test. + */ + override = 1; + PyDictKeysObject *okeys = other->ma_keys; + + // If other is clean, combined, and just allocated, just clone it. + if (other->ma_values == NULL && + other->ma_used == okeys->dk_nentries && + (DK_LOG_SIZE(okeys) == PyDict_LOG_MINSIZE || + USABLE_FRACTION(DK_SIZE(okeys)/2) < other->ma_used)) { + uint64_t new_version = _PyDict_NotifyEvent( + interp, PyDict_EVENT_CLONED, mp, (PyObject *)other, NULL); + PyDictKeysObject *keys = clone_combined_dict_keys(other); + if (keys == NULL) + return -1; + + dictkeys_decref(interp, mp->ma_keys); + mp->ma_keys = keys; + if (mp->ma_values != NULL) { + free_values(mp->ma_values); + mp->ma_values = NULL; + } + + mp->ma_used = other->ma_used; + mp->ma_version_tag = new_version; + ASSERT_CONSISTENT(mp); + + if (_PyObject_GC_IS_TRACKED(other) && !_PyObject_GC_IS_TRACKED(mp)) { + /* Maintain tracking. */ + _PyObject_GC_TRACK(mp); + } + + return 0; + } + } + /* Do one big resize at the start, rather than + * incrementally resizing as we insert new items. Expect + * that there will be no (or few) overlapping keys. + */ + if (USABLE_FRACTION(DK_SIZE(mp->ma_keys)) < other->ma_used) { + int unicode = DK_IS_UNICODE(other->ma_keys); + if (dictresize(interp, mp, + estimate_log2_keysize(mp->ma_used + other->ma_used), + unicode)) { + return -1; + } + } + + Py_ssize_t orig_size = other->ma_keys->dk_nentries; + Py_ssize_t pos = 0; + Py_hash_t hash; + PyObject *key, *value; + + while (_PyDict_Next((PyObject*)other, &pos, &key, &value, &hash)) { + int err = 0; + Py_INCREF(key); + Py_INCREF(value); + if (override == 1) { + err = insertdict(interp, mp, + Py_NewRef(key), hash, Py_NewRef(value)); + } + else { + err = contains_known_hash_lock_held(mp, key, hash); + if (err == 0) { + err = insertdict(interp, mp, + Py_NewRef(key), hash, Py_NewRef(value)); + } + else if (err > 0) { + if (override != 0) { + _PyErr_SetKeyError(key); + Py_DECREF(value); + Py_DECREF(key); + return -1; + } + err = 0; + } + } + Py_DECREF(value); + Py_DECREF(key); + if (err != 0) + return -1; + + if (orig_size != other->ma_keys->dk_nentries) { + PyErr_SetString(PyExc_RuntimeError, + "dict mutated during update"); + return -1; + } + } + return 0; +} + static int dict_merge(PyInterpreterState *interp, PyObject *a, PyObject *b, int override) { @@ -2912,127 +3233,44 @@ dict_merge(PyInterpreterState *interp, PyObject *a, PyObject *b, int override) return -1; } mp = (PyDictObject*)a; + int res = 0; if (PyDict_Check(b) && (Py_TYPE(b)->tp_iter == dict_iter)) { other = (PyDictObject*)b; - if (other == mp || other->ma_used == 0) - /* a.update(a) or a.update({}); nothing to do */ - return 0; - if (mp->ma_used == 0) { - /* Since the target dict is empty, PyDict_GetItem() - * always returns NULL. Setting override to 1 - * skips the unnecessary test. - */ - override = 1; - PyDictKeysObject *okeys = other->ma_keys; - - // If other is clean, combined, and just allocated, just clone it. - if (other->ma_values == NULL && - other->ma_used == okeys->dk_nentries && - (DK_LOG_SIZE(okeys) == PyDict_LOG_MINSIZE || - USABLE_FRACTION(DK_SIZE(okeys)/2) < other->ma_used)) { - uint64_t new_version = _PyDict_NotifyEvent( - interp, PyDict_EVENT_CLONED, mp, b, NULL); - PyDictKeysObject *keys = clone_combined_dict_keys(other); - if (keys == NULL) { - return -1; - } - - dictkeys_decref(interp, mp->ma_keys); - mp->ma_keys = keys; - if (mp->ma_values != NULL) { - free_values(mp->ma_values); - mp->ma_values = NULL; - } - - mp->ma_used = other->ma_used; - mp->ma_version_tag = new_version; - ASSERT_CONSISTENT(mp); - - if (_PyObject_GC_IS_TRACKED(other) && !_PyObject_GC_IS_TRACKED(mp)) { - /* Maintain tracking. */ - _PyObject_GC_TRACK(mp); - } - - return 0; - } - } - /* Do one big resize at the start, rather than - * incrementally resizing as we insert new items. Expect - * that there will be no (or few) overlapping keys. - */ - if (USABLE_FRACTION(DK_SIZE(mp->ma_keys)) < other->ma_used) { - int unicode = DK_IS_UNICODE(other->ma_keys); - if (dictresize(interp, mp, - estimate_log2_keysize(mp->ma_used + other->ma_used), - unicode)) { - return -1; - } - } - - Py_ssize_t orig_size = other->ma_keys->dk_nentries; - Py_ssize_t pos = 0; - Py_hash_t hash; - PyObject *key, *value; - - while (_PyDict_Next((PyObject*)other, &pos, &key, &value, &hash)) { - int err = 0; - Py_INCREF(key); - Py_INCREF(value); - if (override == 1) { - err = insertdict(interp, mp, - Py_NewRef(key), hash, Py_NewRef(value)); - } - else { - err = _PyDict_Contains_KnownHash(a, key, hash); - if (err == 0) { - err = insertdict(interp, mp, - Py_NewRef(key), hash, Py_NewRef(value)); - } - else if (err > 0) { - if (override != 0) { - _PyErr_SetKeyError(key); - Py_DECREF(value); - Py_DECREF(key); - return -1; - } - err = 0; - } - } - Py_DECREF(value); - Py_DECREF(key); - if (err != 0) - return -1; - - if (orig_size != other->ma_keys->dk_nentries) { - PyErr_SetString(PyExc_RuntimeError, - "dict mutated during update"); - return -1; - } - } + int res; + Py_BEGIN_CRITICAL_SECTION2(a, b); + res = dict_dict_merge(interp, (PyDictObject *)a, other, override); + ASSERT_CONSISTENT(a); + Py_END_CRITICAL_SECTION2(); + return res; } else { /* Do it the generic, slower way */ + Py_BEGIN_CRITICAL_SECTION(a); PyObject *keys = PyMapping_Keys(b); PyObject *iter; PyObject *key, *value; int status; - if (keys == NULL) + if (keys == NULL) { /* Docstring says this is equivalent to E.keys() so * if E doesn't have a .keys() method we want * AttributeError to percolate up. Might as well * do the same for any other error. */ - return -1; + res = -1; + goto slow_exit; + } iter = PyObject_GetIter(keys); Py_DECREF(keys); - if (iter == NULL) - return -1; + if (iter == NULL) { + res = -1; + goto slow_exit; + } for (key = PyIter_Next(iter); key; key = PyIter_Next(iter)) { if (override != 1) { - status = PyDict_Contains(a, key); + status = contains_lock_held(mp, key); if (status != 0) { if (status > 0) { if (override == 0) { @@ -3043,30 +3281,39 @@ dict_merge(PyInterpreterState *interp, PyObject *a, PyObject *b, int override) } Py_DECREF(key); Py_DECREF(iter); - return -1; + res = -1; + goto slow_exit; } } value = PyObject_GetItem(b, key); if (value == NULL) { Py_DECREF(iter); Py_DECREF(key); - return -1; + res = -1; + goto slow_exit; } - status = PyDict_SetItem(a, key, value); + status = setitem_lock_held(mp, key, value); Py_DECREF(key); Py_DECREF(value); if (status < 0) { Py_DECREF(iter); + res = -1; + goto slow_exit; return -1; } } Py_DECREF(iter); - if (PyErr_Occurred()) + if (PyErr_Occurred()) { /* Iterator completed, via error */ - return -1; + res = -1; + goto slow_exit; + } + +slow_exit: + ASSERT_CONSISTENT(a); + Py_END_CRITICAL_SECTION(); + return res; } - ASSERT_CONSISTENT(a); - return 0; } int @@ -3104,17 +3351,14 @@ dict_copy_impl(PyDictObject *self) return PyDict_Copy((PyObject *)self); } -PyObject * -PyDict_Copy(PyObject *o) +static PyObject * +copy_lock_held(PyObject *o) { PyObject *copy; PyDictObject *mp; PyInterpreterState *interp = _PyInterpreterState_GET(); - if (o == NULL || !PyDict_Check(o)) { - PyErr_BadInternalCall(); - return NULL; - } + ASSERT_DICT_LOCKED(o); mp = (PyDictObject *)o; if (mp->ma_used == 0) { @@ -3197,6 +3441,23 @@ PyDict_Copy(PyObject *o) return NULL; } +PyObject * +PyDict_Copy(PyObject *o) +{ + if (o == NULL || !PyDict_Check(o)) { + PyErr_BadInternalCall(); + return NULL; + } + + PyObject *res; + Py_BEGIN_CRITICAL_SECTION(o); + + res = copy_lock_held(o); + + Py_END_CRITICAL_SECTION(); + return res; +} + Py_ssize_t PyDict_Size(PyObject *mp) { @@ -3212,10 +3473,13 @@ PyDict_Size(PyObject *mp) * Uses only Py_EQ comparison. */ static int -dict_equal(PyDictObject *a, PyDictObject *b) +dict_equal_lock_held(PyDictObject *a, PyDictObject *b) { Py_ssize_t i; + ASSERT_DICT_LOCKED(a); + ASSERT_DICT_LOCKED(b); + if (a->ma_used != b->ma_used) /* can't be equal if # of entries differ */ return 0; @@ -3270,6 +3534,17 @@ dict_equal(PyDictObject *a, PyDictObject *b) return 1; } +static int +dict_equal(PyDictObject *a, PyDictObject *b) +{ + int res; + Py_BEGIN_CRITICAL_SECTION2(a, b); + res = dict_equal_lock_held(a, b); + Py_END_CRITICAL_SECTION2(); + + return res; +} + static PyObject * dict_richcompare(PyObject *v, PyObject *w, int op) { @@ -3293,6 +3568,7 @@ dict_richcompare(PyObject *v, PyObject *w, int op) /*[clinic input] @coexist +@critical_section dict.__contains__ key: object @@ -3302,8 +3578,8 @@ True if the dictionary has the specified key, else False. [clinic start generated code]*/ static PyObject * -dict___contains__(PyDictObject *self, PyObject *key) -/*[clinic end generated code: output=a3d03db709ed6e6b input=fe1cb42ad831e820]*/ +dict___contains___impl(PyDictObject *self, PyObject *key) +/*[clinic end generated code: output=1b314e6da7687dae input=bc76ec9c157cb81b]*/ { register PyDictObject *mp = self; Py_hash_t hash; @@ -3324,6 +3600,7 @@ dict___contains__(PyDictObject *self, PyObject *key) } /*[clinic input] +@critical_section dict.get key: object @@ -3335,7 +3612,7 @@ Return the value for key if key is in the dictionary, else default. static PyObject * dict_get_impl(PyDictObject *self, PyObject *key, PyObject *default_value) -/*[clinic end generated code: output=bba707729dee05bf input=279ddb5790b6b107]*/ +/*[clinic end generated code: output=bba707729dee05bf input=a631d3f18f584c60]*/ { PyObject *val = NULL; Py_hash_t hash; @@ -3356,7 +3633,7 @@ dict_get_impl(PyDictObject *self, PyObject *key, PyObject *default_value) } static int -dict_setdefault_ref(PyObject *d, PyObject *key, PyObject *default_value, +dict_setdefault_ref_lock_held(PyObject *d, PyObject *key, PyObject *default_value, PyObject **result, int incref_result) { PyDictObject *mp = (PyDictObject *)d; @@ -3364,6 +3641,8 @@ dict_setdefault_ref(PyObject *d, PyObject *key, PyObject *default_value, Py_hash_t hash; PyInterpreterState *interp = _PyInterpreterState_GET(); + ASSERT_DICT_LOCKED(d); + if (!PyDict_Check(d)) { PyErr_BadInternalCall(); if (result) { @@ -3490,18 +3769,25 @@ int PyDict_SetDefaultRef(PyObject *d, PyObject *key, PyObject *default_value, PyObject **result) { - return dict_setdefault_ref(d, key, default_value, result, 1); + int res; + Py_BEGIN_CRITICAL_SECTION(d); + res = dict_setdefault_ref_lock_held(d, key, default_value, result, 1); + Py_END_CRITICAL_SECTION(); + return res; } PyObject * PyDict_SetDefault(PyObject *d, PyObject *key, PyObject *defaultobj) { PyObject *result; - dict_setdefault_ref(d, key, defaultobj, &result, 0); + Py_BEGIN_CRITICAL_SECTION(d); + dict_setdefault_ref_lock_held(d, key, defaultobj, &result, 0); + Py_END_CRITICAL_SECTION(); return result; } /*[clinic input] +@critical_section dict.setdefault key: object @@ -3516,10 +3802,10 @@ Return the value for key if key is in the dictionary, else default. static PyObject * dict_setdefault_impl(PyDictObject *self, PyObject *key, PyObject *default_value) -/*[clinic end generated code: output=f8c1101ebf69e220 input=0f063756e815fd9d]*/ +/*[clinic end generated code: output=f8c1101ebf69e220 input=9237af9a0a224302]*/ { PyObject *val; - PyDict_SetDefaultRef((PyObject *)self, key, default_value, &val); + dict_setdefault_ref_lock_held((PyObject *)self, key, default_value, &val, 1); return val; } @@ -3559,6 +3845,7 @@ dict_pop_impl(PyDictObject *self, PyObject *key, PyObject *default_value) } /*[clinic input] +@critical_section dict.popitem Remove and return a (key, value) pair as a 2-tuple. @@ -3569,7 +3856,7 @@ Raises KeyError if the dict is empty. static PyObject * dict_popitem_impl(PyDictObject *self) -/*[clinic end generated code: output=e65fcb04420d230d input=1c38a49f21f64941]*/ +/*[clinic end generated code: output=e65fcb04420d230d input=ef28b4da5f0f762e]*/ { Py_ssize_t i, j; PyObject *res; @@ -3695,8 +3982,8 @@ dict_tp_clear(PyObject *op) static PyObject *dictiter_new(PyDictObject *, PyTypeObject *); -Py_ssize_t -_PyDict_SizeOf(PyDictObject *mp) +static Py_ssize_t +sizeof_lock_held(PyDictObject *mp) { size_t res = _PyObject_SIZE(Py_TYPE(mp)); if (mp->ma_values) { @@ -3711,6 +3998,17 @@ _PyDict_SizeOf(PyDictObject *mp) return (Py_ssize_t)res; } +Py_ssize_t +_PyDict_SizeOf(PyDictObject *mp) +{ + Py_ssize_t res; + Py_BEGIN_CRITICAL_SECTION(mp); + res = sizeof_lock_held(mp); + Py_END_CRITICAL_SECTION(); + + return res; +} + size_t _PyDict_KeysSize(PyDictKeysObject *keys) { @@ -3794,15 +4092,29 @@ static PyMethodDef mapp_methods[] = { {NULL, NULL} /* sentinel */ }; -/* Return 1 if `key` is in dict `op`, 0 if not, and -1 on error. */ -int -PyDict_Contains(PyObject *op, PyObject *key) +static int +contains_known_hash_lock_held(PyDictObject *mp, PyObject *key, Py_ssize_t hash) +{ + Py_ssize_t ix; + PyObject *value; + + ASSERT_DICT_LOCKED(mp); + + ix = _Py_dict_lookup(mp, key, hash, &value); + if (ix == DKIX_ERROR) + return -1; + return (ix != DKIX_EMPTY && value != NULL); +} + +static int +contains_lock_held(PyDictObject *mp, PyObject *key) { Py_hash_t hash; Py_ssize_t ix; - PyDictObject *mp = (PyDictObject *)op; PyObject *value; + ASSERT_DICT_LOCKED(mp); + if (!PyUnicode_CheckExact(key) || (hash = unicode_get_hash(key)) == -1) { hash = PyObject_Hash(key); if (hash == -1) @@ -3814,6 +4126,17 @@ PyDict_Contains(PyObject *op, PyObject *key) return (ix != DKIX_EMPTY && value != NULL); } +/* Return 1 if `key` is in dict `op`, 0 if not, and -1 on error. */ +int +PyDict_Contains(PyObject *op, PyObject *key) +{ + int res; + Py_BEGIN_CRITICAL_SECTION(op); + res = contains_lock_held((PyDictObject *)op, key); + Py_END_CRITICAL_SECTION(); + return res; +} + int PyDict_ContainsString(PyObject *op, const char *key) { @@ -4180,17 +4503,15 @@ static PyMethodDef dictiter_methods[] = { }; static PyObject* -dictiter_iternextkey(PyObject *self) +dictiter_iternextkey_lock_held(PyDictObject *d, PyObject *self) { dictiterobject *di = (dictiterobject *)self; PyObject *key; Py_ssize_t i; PyDictKeysObject *k; - PyDictObject *d = di->di_dict; - if (d == NULL) - return NULL; assert (PyDict_Check(d)); + ASSERT_DICT_LOCKED(d); if (di->di_used != d->ma_used) { PyErr_SetString(PyExc_RuntimeError, @@ -4248,6 +4569,23 @@ dictiter_iternextkey(PyObject *self) return NULL; } +static PyObject* +dictiter_iternextkey(PyObject *self) +{ + dictiterobject *di = (dictiterobject *)self; + PyDictObject *d = di->di_dict; + + if (d == NULL) + return NULL; + + PyObject *value; + Py_BEGIN_CRITICAL_SECTION(d); + value = dictiter_iternextkey_lock_held(d, self); + Py_END_CRITICAL_SECTION(); + + return value; +} + PyTypeObject PyDictIterKey_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) "dict_keyiterator", /* tp_name */ @@ -4282,16 +4620,14 @@ PyTypeObject PyDictIterKey_Type = { }; static PyObject * -dictiter_iternextvalue(PyObject *self) +dictiter_iternextvalue_lock_held(PyDictObject *d, PyObject *self) { dictiterobject *di = (dictiterobject *)self; PyObject *value; Py_ssize_t i; - PyDictObject *d = di->di_dict; - if (d == NULL) - return NULL; assert (PyDict_Check(d)); + ASSERT_DICT_LOCKED(d); if (di->di_used != d->ma_used) { PyErr_SetString(PyExc_RuntimeError, @@ -4348,6 +4684,23 @@ dictiter_iternextvalue(PyObject *self) return NULL; } +static PyObject * +dictiter_iternextvalue(PyObject *self) +{ + dictiterobject *di = (dictiterobject *)self; + PyDictObject *d = di->di_dict; + + if (d == NULL) + return NULL; + + PyObject *value; + Py_BEGIN_CRITICAL_SECTION(d); + value = dictiter_iternextvalue_lock_held(d, self); + Py_END_CRITICAL_SECTION(); + + return value; +} + PyTypeObject PyDictIterValue_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) "dict_valueiterator", /* tp_name */ @@ -4382,15 +4735,12 @@ PyTypeObject PyDictIterValue_Type = { }; static PyObject * -dictiter_iternextitem(PyObject *self) +dictiter_iternextitem_lock_held(PyDictObject *d, PyObject *self) { dictiterobject *di = (dictiterobject *)self; PyObject *key, *value, *result; Py_ssize_t i; - PyDictObject *d = di->di_dict; - if (d == NULL) - return NULL; assert (PyDict_Check(d)); if (di->di_used != d->ma_used) { @@ -4473,6 +4823,22 @@ dictiter_iternextitem(PyObject *self) return NULL; } +static PyObject * +dictiter_iternextitem(PyObject *self) +{ + dictiterobject *di = (dictiterobject *)self; + PyDictObject *d = di->di_dict; + + if (d == NULL) + return NULL; + + PyObject *item; + Py_BEGIN_CRITICAL_SECTION(d); + item = dictiter_iternextitem_lock_held(d, self); + Py_END_CRITICAL_SECTION(); + return item; +} + PyTypeObject PyDictIterItem_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) "dict_itemiterator", /* tp_name */ @@ -4510,15 +4876,12 @@ PyTypeObject PyDictIterItem_Type = { /* dictreviter */ static PyObject * -dictreviter_iternext(PyObject *self) +dictreviter_iter_PyDict_Next(PyDictObject *d, PyObject *self) { dictiterobject *di = (dictiterobject *)self; - PyDictObject *d = di->di_dict; - if (d == NULL) { - return NULL; - } assert (PyDict_Check(d)); + ASSERT_DICT_LOCKED(d); if (di->di_used != d->ma_used) { PyErr_SetString(PyExc_RuntimeError, @@ -4609,6 +4972,23 @@ dictreviter_iternext(PyObject *self) return NULL; } +static PyObject * +dictreviter_iternext(PyObject *self) +{ + dictiterobject *di = (dictiterobject *)self; + PyDictObject *d = di->di_dict; + + if (d == NULL) + return NULL; + + PyObject *value; + Py_BEGIN_CRITICAL_SECTION(d); + value = dictreviter_iter_PyDict_Next(d, self); + Py_END_CRITICAL_SECTION(); + + return value; +} + PyTypeObject PyDictRevIterKey_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) "dict_reversekeyiterator", @@ -5037,14 +5417,12 @@ dictviews_or(PyObject* self, PyObject *other) } static PyObject * -dictitems_xor(PyObject *self, PyObject *other) +dictitems_xor_lock_held(PyObject *d1, PyObject *d2) { - assert(PyDictItems_Check(self)); - assert(PyDictItems_Check(other)); - PyObject *d1 = (PyObject *)((_PyDictViewObject *)self)->dv_dict; - PyObject *d2 = (PyObject *)((_PyDictViewObject *)other)->dv_dict; + ASSERT_DICT_LOCKED(d1); + ASSERT_DICT_LOCKED(d2); - PyObject *temp_dict = PyDict_Copy(d1); + PyObject *temp_dict = copy_lock_held(d1); if (temp_dict == NULL) { return NULL; } @@ -5122,6 +5500,22 @@ dictitems_xor(PyObject *self, PyObject *other) return NULL; } +static PyObject * +dictitems_xor(PyObject *self, PyObject *other) +{ + assert(PyDictItems_Check(self)); + assert(PyDictItems_Check(other)); + PyObject *d1 = (PyObject *)((_PyDictViewObject *)self)->dv_dict; + PyObject *d2 = (PyObject *)((_PyDictViewObject *)other)->dv_dict; + + PyObject *res; + Py_BEGIN_CRITICAL_SECTION2(d1, d2); + res = dictitems_xor_lock_held(d1, d2); + Py_END_CRITICAL_SECTION2(); + + return res; +} + static PyObject* dictviews_xor(PyObject* self, PyObject *other) { diff --git a/Objects/odictobject.c b/Objects/odictobject.c index b5280c39e1be54..421bc52992d735 100644 --- a/Objects/odictobject.c +++ b/Objects/odictobject.c @@ -465,12 +465,13 @@ Potential Optimizations */ #include "Python.h" -#include "pycore_call.h" // _PyObject_CallNoArgs() -#include "pycore_ceval.h" // _PyEval_GetBuiltin() -#include "pycore_dict.h" // _Py_dict_lookup() -#include "pycore_object.h" // _PyObject_GC_UNTRACK() -#include "pycore_pyerrors.h" // _PyErr_ChainExceptions1() -#include // offsetof() +#include "pycore_call.h" // _PyObject_CallNoArgs() +#include "pycore_ceval.h" // _PyEval_GetBuiltin() +#include "pycore_critical_section.h" //_Py_BEGIN_CRITICAL_SECTION +#include "pycore_dict.h" // _Py_dict_lookup() +#include "pycore_object.h" // _PyObject_GC_UNTRACK() +#include "pycore_pyerrors.h" // _PyErr_ChainExceptions1() +#include // offsetof() #include "clinic/odictobject.c.h" @@ -1039,6 +1040,8 @@ _odict_popkey_hash(PyObject *od, PyObject *key, PyObject *failobj, { PyObject *value = NULL; + Py_BEGIN_CRITICAL_SECTION(od); + _ODictNode *node = _odict_find_node_hash((PyODictObject *)od, key, hash); if (node != NULL) { /* Pop the node first to avoid a possible dict resize (due to @@ -1046,7 +1049,7 @@ _odict_popkey_hash(PyObject *od, PyObject *key, PyObject *failobj, resolution. */ int res = _odict_clear_node((PyODictObject *)od, node, key, hash); if (res < 0) { - return NULL; + goto done; } /* Now delete the value from the dict. */ if (_PyDict_Pop_KnownHash((PyDictObject *)od, key, hash, @@ -1063,6 +1066,8 @@ _odict_popkey_hash(PyObject *od, PyObject *key, PyObject *failobj, PyErr_SetObject(PyExc_KeyError, key); } } + Py_END_CRITICAL_SECTION(); +done: return value; } diff --git a/Objects/setobject.c b/Objects/setobject.c index 93de8e84f2ddf9..3acf2a7a74890b 100644 --- a/Objects/setobject.c +++ b/Objects/setobject.c @@ -32,13 +32,14 @@ */ #include "Python.h" -#include "pycore_ceval.h" // _PyEval_GetBuiltin() -#include "pycore_dict.h" // _PyDict_Contains_KnownHash() -#include "pycore_modsupport.h" // _PyArg_NoKwnames() -#include "pycore_object.h" // _PyObject_GC_UNTRACK() -#include "pycore_pyerrors.h" // _PyErr_SetKeyError() -#include "pycore_setobject.h" // _PySet_NextEntry() definition -#include // offsetof() +#include "pycore_ceval.h" // _PyEval_GetBuiltin() +#include "pycore_critical_section.h" // Py_BEGIN_CRITICAL_SECTION, Py_END_CRITICAL_SECTION +#include "pycore_dict.h" // _PyDict_Contains_KnownHash() +#include "pycore_modsupport.h" // _PyArg_NoKwnames() +#include "pycore_object.h" // _PyObject_GC_UNTRACK() +#include "pycore_pyerrors.h" // _PyErr_SetKeyError() +#include "pycore_setobject.h" // _PySet_NextEntry() definition +#include // offsetof() /* Object used as dummy key to fill deleted entries */ static PyObject _dummy_struct; @@ -903,11 +904,17 @@ set_update_internal(PySetObject *so, PyObject *other) if (set_table_resize(so, (so->used + dictsize)*2) != 0) return -1; } + int err = 0; + Py_BEGIN_CRITICAL_SECTION(other); while (_PyDict_Next(other, &pos, &key, &value, &hash)) { - if (set_add_entry(so, key, hash)) - return -1; + if (set_add_entry(so, key, hash)) { + err = -1; + goto exit; + } } - return 0; +exit: + Py_END_CRITICAL_SECTION(); + return err; } it = PyObject_GetIter(other); @@ -1620,6 +1627,33 @@ set_isub(PySetObject *so, PyObject *other) return Py_NewRef(so); } +static PyObject * +set_symmetric_difference_update_dict(PySetObject *so, PyObject *other) +{ + PyObject *key; + Py_ssize_t pos = 0; + Py_hash_t hash; + PyObject *value; + int rv; + + while (_PyDict_Next(other, &pos, &key, &value, &hash)) { + Py_INCREF(key); + rv = set_discard_entry(so, key, hash); + if (rv < 0) { + Py_DECREF(key); + return NULL; + } + if (rv == DISCARD_NOTFOUND) { + if (set_add_entry(so, key, hash)) { + Py_DECREF(key); + return NULL; + } + } + Py_DECREF(key); + } + Py_RETURN_NONE; +} + static PyObject * set_symmetric_difference_update(PySetObject *so, PyObject *other) { @@ -1634,23 +1668,13 @@ set_symmetric_difference_update(PySetObject *so, PyObject *other) return set_clear(so, NULL); if (PyDict_CheckExact(other)) { - PyObject *value; - while (_PyDict_Next(other, &pos, &key, &value, &hash)) { - Py_INCREF(key); - rv = set_discard_entry(so, key, hash); - if (rv < 0) { - Py_DECREF(key); - return NULL; - } - if (rv == DISCARD_NOTFOUND) { - if (set_add_entry(so, key, hash)) { - Py_DECREF(key); - return NULL; - } - } - Py_DECREF(key); - } - Py_RETURN_NONE; + PyObject *res; + + Py_BEGIN_CRITICAL_SECTION(other); + res = set_symmetric_difference_update_dict(so, other); + Py_END_CRITICAL_SECTION(); + + return res; } if (PyAnySet_Check(other)) {