From af9a8b415ea1f9f10f2b91189b7cfd9759770652 Mon Sep 17 00:00:00 2001 From: Neil Schemenauer Date: Fri, 20 Dec 2024 16:23:16 -0800 Subject: [PATCH 1/4] Add 'context' parameter to Thread. By default, inherit the context from the thread calling `Thread.start()`. --- Doc/library/threading.rst | 10 ++++++++- Lib/test/test_context.py | 43 +++++++++++++++++++++++++++++++++++++++ Lib/test/test_decimal.py | 4 ++-- Lib/threading.py | 26 ++++++++++++++++++++--- 4 files changed, 77 insertions(+), 6 deletions(-) diff --git a/Doc/library/threading.rst b/Doc/library/threading.rst index f183f3f535c4cb..f61f7208effe17 100644 --- a/Doc/library/threading.rst +++ b/Doc/library/threading.rst @@ -334,7 +334,7 @@ since it is impossible to detect the termination of alien threads. .. class:: Thread(group=None, target=None, name=None, args=(), kwargs={}, *, \ - daemon=None) + daemon=None, context="inherit") This constructor should always be called with keyword arguments. Arguments are: @@ -359,6 +359,10 @@ since it is impossible to detect the termination of alien threads. If ``None`` (the default), the daemonic property is inherited from the current thread. + *context* is the `contextvars.Context` value to use while running the thread. + The default is to inherit the context of the caller of :meth:`~Thread.start`. + If set to ``None``, the context will be empty. + If the subclass overrides the constructor, it must make sure to invoke the base class constructor (``Thread.__init__()``) before doing anything else to the thread. @@ -369,6 +373,10 @@ since it is impossible to detect the termination of alien threads. .. versionchanged:: 3.10 Use the *target* name if *name* argument is omitted. + .. versionchanged:: 3.14 + Added the *context* parameter. Previously threads always ran with an empty + context. + .. method:: start() Start the thread's activity. diff --git a/Lib/test/test_context.py b/Lib/test/test_context.py index 82d1797ab3b79e..5be25655e75569 100644 --- a/Lib/test/test_context.py +++ b/Lib/test/test_context.py @@ -383,6 +383,49 @@ def sub(num): tp.shutdown() self.assertEqual(results, list(range(10))) + @isolated_context + @threading_helper.requires_working_threading() + def test_context_thread_inherit(self): + import threading + + cvar = contextvars.ContextVar('cvar') + + # By default, the context of the caller is inheritied + def run_inherit(): + self.assertEqual(cvar.get(), 1) + + cvar.set(1) + thread = threading.Thread(target=run_inherit) + thread.start() + thread.join() + + # If context=None is passed, the thread has an empty context + def run_empty(): + with self.assertRaises(LookupError): + cvar.get() + + thread = threading.Thread(target=run_empty, context=None) + thread.start() + thread.join() + + # An explicit Context value can also be passed + custom_ctx = contextvars.Context() + custom_var = None + + def setup_context(): + nonlocal custom_var + custom_var = contextvars.ContextVar('custom') + custom_var.set(2) + + custom_ctx.run(setup_context) + + def run_custom(): + self.assertEqual(custom_var.get(), 2) + + thread = threading.Thread(target=run_custom, context=custom_ctx) + thread.start() + thread.join() + # HAMT Tests diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py index bc6c6427740949..3d14f273e1597a 100644 --- a/Lib/test/test_decimal.py +++ b/Lib/test/test_decimal.py @@ -1725,8 +1725,8 @@ def test_threading(self): self.finish1 = threading.Event() self.finish2 = threading.Event() - th1 = threading.Thread(target=thfunc1, args=(self,)) - th2 = threading.Thread(target=thfunc2, args=(self,)) + th1 = threading.Thread(target=thfunc1, args=(self,), context=None) + th2 = threading.Thread(target=thfunc2, args=(self,), context=None) th1.start() th2.start() diff --git a/Lib/threading.py b/Lib/threading.py index 78e591124278fc..b2e5f573ceb96f 100644 --- a/Lib/threading.py +++ b/Lib/threading.py @@ -4,6 +4,8 @@ import sys as _sys import _thread import warnings +import contextvars as _contextvars + from time import monotonic as _time from _weakrefset import WeakSet @@ -871,7 +873,7 @@ class Thread: _initialized = False def __init__(self, group=None, target=None, name=None, - args=(), kwargs=None, *, daemon=None): + args=(), kwargs=None, *, daemon=None, context='inherit'): """This constructor should always be called with keyword arguments. Arguments are: *group* should be None; reserved for future extension when a ThreadGroup @@ -888,6 +890,10 @@ class is implemented. *kwargs* is a dictionary of keyword arguments for the target invocation. Defaults to {}. + *context* is the contextvars.Context value to use for the thread. The default + is to inherit the context of the caller. Set to None to start with an empty + context. + If a subclass overrides the constructor, it must make sure to invoke the base class constructor (Thread.__init__()) before doing anything else to the thread. @@ -917,6 +923,7 @@ class is implemented. self._daemonic = daemon else: self._daemonic = current_thread().daemon + self._context = context self._ident = None if _HAVE_THREAD_NATIVE_ID: self._native_id = None @@ -972,9 +979,15 @@ def start(self): with _active_limbo_lock: _limbo[self] = self + + if self._context == 'inherit': + # No context provided, inherit the context of the caller. + self._context = _contextvars.copy_context() + try: # Start joinable thread - _start_joinable_thread(self._bootstrap, handle=self._handle, + _start_joinable_thread(self._bootstrap, + handle=self._handle, daemon=self.daemon) except Exception: with _active_limbo_lock: @@ -1050,8 +1063,15 @@ def _bootstrap_inner(self): if _profile_hook: _sys.setprofile(_profile_hook) + try: - self.run() + if self._context is None: + # Run with empty context, matching behaviour of + # threading.local and older versions of Python. + self.run() + else: + # Run with the provided or the inherited context. + self._context.run(self.run) except: self._invoke_excepthook(self) finally: From 49cdd63683c1ead005a0dd57bd7f93994a433b17 Mon Sep 17 00:00:00 2001 From: Neil Schemenauer Date: Sat, 28 Dec 2024 13:00:59 -0800 Subject: [PATCH 2/4] Make _contextvars a builtin module. --- Makefile.pre.in | 1 + Modules/Setup | 1 - Modules/Setup.stdlib.in | 1 - Modules/config.c.in | 4 ++++ PCbuild/pythoncore.vcxproj | 2 +- PCbuild/pythoncore.vcxproj.filters | 6 +++--- Modules/_contextvarsmodule.c => Python/_contextvars.c | 2 +- .../clinic/_contextvars.c.h | 0 configure.ac | 1 - 9 files changed, 10 insertions(+), 8 deletions(-) rename Modules/_contextvarsmodule.c => Python/_contextvars.c (97%) rename Modules/clinic/_contextvarsmodule.c.h => Python/clinic/_contextvars.c.h (100%) diff --git a/Makefile.pre.in b/Makefile.pre.in index 67acf0fc520087..18484a42abb6e8 100644 --- a/Makefile.pre.in +++ b/Makefile.pre.in @@ -420,6 +420,7 @@ PARSER_HEADERS= \ # Python PYTHON_OBJS= \ + Python/_contextvars.o \ Python/_warnings.o \ Python/Python-ast.o \ Python/Python-tokenize.o \ diff --git a/Modules/Setup b/Modules/Setup index ddf39e0b966610..e01c7bb1a8a45e 100644 --- a/Modules/Setup +++ b/Modules/Setup @@ -132,7 +132,6 @@ PYTHONPATH=$(COREPYTHONPATH) #_asyncio _asynciomodule.c #_bisect _bisectmodule.c -#_contextvars _contextvarsmodule.c #_csv _csv.c #_datetime _datetimemodule.c #_decimal _decimal/_decimal.c diff --git a/Modules/Setup.stdlib.in b/Modules/Setup.stdlib.in index 52c0f883d383db..189c30c558336e 100644 --- a/Modules/Setup.stdlib.in +++ b/Modules/Setup.stdlib.in @@ -31,7 +31,6 @@ @MODULE_ARRAY_TRUE@array arraymodule.c @MODULE__ASYNCIO_TRUE@_asyncio _asynciomodule.c @MODULE__BISECT_TRUE@_bisect _bisectmodule.c -@MODULE__CONTEXTVARS_TRUE@_contextvars _contextvarsmodule.c @MODULE__CSV_TRUE@_csv _csv.c @MODULE__HEAPQ_TRUE@_heapq _heapqmodule.c @MODULE__JSON_TRUE@_json _json.c diff --git a/Modules/config.c.in b/Modules/config.c.in index 53b4fb285498d0..41e42228c1ffd1 100644 --- a/Modules/config.c.in +++ b/Modules/config.c.in @@ -29,6 +29,7 @@ extern PyObject* PyInit__imp(void); extern PyObject* PyInit_gc(void); extern PyObject* PyInit__ast(void); extern PyObject* PyInit__tokenize(void); +extern PyObject* PyInit__contextvars(void); extern PyObject* _PyWarnings_Init(void); extern PyObject* PyInit__string(void); @@ -55,6 +56,9 @@ struct _inittab _PyImport_Inittab[] = { /* This lives in gcmodule.c */ {"gc", PyInit_gc}, + /* This lives in Python/_contextvars.c */ + {"_contextvars", PyInit__contextvars}, + /* This lives in _warnings.c */ {"_warnings", _PyWarnings_Init}, diff --git a/PCbuild/pythoncore.vcxproj b/PCbuild/pythoncore.vcxproj index 9ebf58ae8a9bc4..ef6dbf9f8e4222 100644 --- a/PCbuild/pythoncore.vcxproj +++ b/PCbuild/pythoncore.vcxproj @@ -423,7 +423,6 @@ - @@ -570,6 +569,7 @@ + diff --git a/PCbuild/pythoncore.vcxproj.filters b/PCbuild/pythoncore.vcxproj.filters index 6c76a6ab592a84..b661aad2019454 100644 --- a/PCbuild/pythoncore.vcxproj.filters +++ b/PCbuild/pythoncore.vcxproj.filters @@ -1262,6 +1262,9 @@ PC + + Python + Python @@ -1526,9 +1529,6 @@ Objects - - Modules - Modules\zlib diff --git a/Modules/_contextvarsmodule.c b/Python/_contextvars.c similarity index 97% rename from Modules/_contextvarsmodule.c rename to Python/_contextvars.c index 3f96f07909b69a..0f8b8004c1af22 100644 --- a/Modules/_contextvarsmodule.c +++ b/Python/_contextvars.c @@ -1,6 +1,6 @@ #include "Python.h" -#include "clinic/_contextvarsmodule.c.h" +#include "clinic/_contextvars.c.h" /*[clinic input] module _contextvars diff --git a/Modules/clinic/_contextvarsmodule.c.h b/Python/clinic/_contextvars.c.h similarity index 100% rename from Modules/clinic/_contextvarsmodule.c.h rename to Python/clinic/_contextvars.c.h diff --git a/configure.ac b/configure.ac index 9e131ed1a2dc98..ada76ed4ef769e 100644 --- a/configure.ac +++ b/configure.ac @@ -7703,7 +7703,6 @@ dnl always enabled extension modules PY_STDLIB_MOD_SIMPLE([array]) PY_STDLIB_MOD_SIMPLE([_asyncio]) PY_STDLIB_MOD_SIMPLE([_bisect]) -PY_STDLIB_MOD_SIMPLE([_contextvars]) PY_STDLIB_MOD_SIMPLE([_csv]) PY_STDLIB_MOD_SIMPLE([_heapq]) PY_STDLIB_MOD_SIMPLE([_json]) From da96aa86cb4c9ec72f39a3857d00c2c9b10faa43 Mon Sep 17 00:00:00 2001 From: Neil Schemenauer Date: Fri, 27 Dec 2024 13:29:48 -0800 Subject: [PATCH 3/4] Add locking to warnings.py. Expose the mutex from _warnings.c and hold it when mutating the filters list or the filters version. --- Include/internal/pycore_warnings.h | 2 +- Lib/test/test_warnings/__init__.py | 2 +- Lib/warnings.py | 170 +++++++++++++++++------------ Python/_warnings.c | 126 ++++++++++++++------- Python/clinic/_warnings.c.h | 50 +++++++-- 5 files changed, 231 insertions(+), 119 deletions(-) diff --git a/Include/internal/pycore_warnings.h b/Include/internal/pycore_warnings.h index f9f6559312f4ef..672228cd6fbd19 100644 --- a/Include/internal/pycore_warnings.h +++ b/Include/internal/pycore_warnings.h @@ -14,7 +14,7 @@ struct _warnings_runtime_state { PyObject *filters; /* List */ PyObject *once_registry; /* Dict */ PyObject *default_action; /* String */ - PyMutex mutex; + _PyRecursiveMutex lock; long filters_version; }; diff --git a/Lib/test/test_warnings/__init__.py b/Lib/test/test_warnings/__init__.py index 4e3c877896f295..4bd164b8a9a82b 100644 --- a/Lib/test/test_warnings/__init__.py +++ b/Lib/test/test_warnings/__init__.py @@ -1521,7 +1521,7 @@ def test_late_resource_warning(self): self.assertTrue(err.startswith(expected), ascii(err)) -class DeprecatedTests(unittest.TestCase): +class DeprecatedTests(PyPublicAPITests): def test_dunder_deprecated(self): @deprecated("A will go away soon") class A: diff --git a/Lib/warnings.py b/Lib/warnings.py index e83cde37ab2d1a..d679974a46c658 100644 --- a/Lib/warnings.py +++ b/Lib/warnings.py @@ -185,24 +185,32 @@ def simplefilter(action, category=Warning, lineno=0, append=False): raise ValueError("lineno must be an int >= 0") _add_filter(action, None, category, None, lineno, append=append) +def _filters_mutated(): + # Even though this function is part of the public API, it's used by + # a fair amount of user code. + with _lock: + _filters_mutated_lock_held() + def _add_filter(*item, append): - # Remove possible duplicate filters, so new one will be placed - # in correct place. If append=True and duplicate exists, do nothing. - if not append: - try: - filters.remove(item) - except ValueError: - pass - filters.insert(0, item) - else: - if item not in filters: - filters.append(item) - _filters_mutated() + with _lock: + if not append: + # Remove possible duplicate filters, so new one will be placed + # in correct place. If append=True and duplicate exists, do nothing. + try: + filters.remove(item) + except ValueError: + pass + filters.insert(0, item) + else: + if item not in filters: + filters.append(item) + _filters_mutated_lock_held() def resetwarnings(): """Clear the list of warning filters, so that no filters are active.""" - filters[:] = [] - _filters_mutated() + with _lock: + filters[:] = [] + _filters_mutated_lock_held() class _OptionError(Exception): """Exception used by option processing helpers.""" @@ -353,11 +361,6 @@ def warn_explicit(message, category, filename, lineno, module = filename or "" if module[-3:].lower() == ".py": module = module[:-3] # XXX What about leading pathname? - if registry is None: - registry = {} - if registry.get('version', 0) != _filters_version: - registry.clear() - registry['version'] = _filters_version if isinstance(message, Warning): text = str(message) category = message.__class__ @@ -365,52 +368,59 @@ def warn_explicit(message, category, filename, lineno, text = message message = category(message) key = (text, category, lineno) - # Quick test for common case - if registry.get(key): - return - # Search the filters - for item in filters: - action, msg, cat, mod, ln = item - if ((msg is None or msg.match(text)) and - issubclass(category, cat) and - (mod is None or mod.match(module)) and - (ln == 0 or lineno == ln)): - break - else: - action = defaultaction - # Early exit actions - if action == "ignore": - return + with _lock: + if registry is None: + registry = {} + if registry.get('version', 0) != _filters_version: + registry.clear() + registry['version'] = _filters_version + # Quick test for common case + if registry.get(key): + return + # Search the filters + for item in filters: + action, msg, cat, mod, ln = item + if ((msg is None or msg.match(text)) and + issubclass(category, cat) and + (mod is None or mod.match(module)) and + (ln == 0 or lineno == ln)): + break + else: + action = defaultaction + # Early exit actions + if action == "ignore": + return + + if action == "error": + raise message + # Other actions + if action == "once": + registry[key] = 1 + oncekey = (text, category) + if onceregistry.get(oncekey): + return + onceregistry[oncekey] = 1 + elif action in {"always", "all"}: + pass + elif action == "module": + registry[key] = 1 + altkey = (text, category, 0) + if registry.get(altkey): + return + registry[altkey] = 1 + elif action == "default": + registry[key] = 1 + else: + # Unrecognized actions are errors + raise RuntimeError( + "Unrecognized action (%r) in warnings.filters:\n %s" % + (action, item)) # Prime the linecache for formatting, in case the # "file" is actually in a zipfile or something. import linecache linecache.getlines(filename, module_globals) - if action == "error": - raise message - # Other actions - if action == "once": - registry[key] = 1 - oncekey = (text, category) - if onceregistry.get(oncekey): - return - onceregistry[oncekey] = 1 - elif action in {"always", "all"}: - pass - elif action == "module": - registry[key] = 1 - altkey = (text, category, 0) - if registry.get(altkey): - return - registry[altkey] = 1 - elif action == "default": - registry[key] = 1 - else: - # Unrecognized actions are errors - raise RuntimeError( - "Unrecognized action (%r) in warnings.filters:\n %s" % - (action, item)) # Print message and context msg = WarningMessage(message, category, filename, lineno, source) _showwarnmsg(msg) @@ -488,11 +498,12 @@ def __enter__(self): if self._entered: raise RuntimeError("Cannot enter %r twice" % self) self._entered = True - self._filters = self._module.filters - self._module.filters = self._filters[:] - self._module._filters_mutated() - self._showwarning = self._module.showwarning - self._showwarnmsg_impl = self._module._showwarnmsg_impl + with _lock: + self._filters = self._module.filters + self._module.filters = self._filters[:] + self._module._filters_mutated_lock_held() + self._showwarning = self._module.showwarning + self._showwarnmsg_impl = self._module._showwarnmsg_impl if self._filter is not None: simplefilter(*self._filter) if self._record: @@ -508,10 +519,11 @@ def __enter__(self): def __exit__(self, *exc_info): if not self._entered: raise RuntimeError("Cannot exit %r without entering first" % self) - self._module.filters = self._filters - self._module._filters_mutated() - self._module.showwarning = self._showwarning - self._module._showwarnmsg_impl = self._showwarnmsg_impl + with _lock: + self._module.filters = self._filters + self._module._filters_mutated_lock_held() + self._module.showwarning = self._showwarning + self._module._showwarnmsg_impl = self._showwarnmsg_impl class deprecated: @@ -701,18 +713,36 @@ def extract(): # If either if the compiled regexs are None, match anything. try: from _warnings import (filters, _defaultaction, _onceregistry, - warn, warn_explicit, _filters_mutated) + warn, warn_explicit, + _filters_mutated_lock_held, + _acquire_lock, _release_lock, + ) defaultaction = _defaultaction onceregistry = _onceregistry _warnings_defaults = True + + class _Lock: + def __enter__(self): + _acquire_lock() + return self + + def __exit__(self, *args): + _release_lock() + + _lock = _Lock() + except ImportError: filters = [] defaultaction = "default" onceregistry = {} + import _thread + + _lock = _thread.RLock() + _filters_version = 1 - def _filters_mutated(): + def _filters_mutated_lock_held(): global _filters_version _filters_version += 1 diff --git a/Python/_warnings.c b/Python/_warnings.c index e05ba99e8eaec4..5302be64403614 100644 --- a/Python/_warnings.c +++ b/Python/_warnings.c @@ -232,6 +232,69 @@ get_warnings_attr(PyInterpreterState *interp, PyObject *attr, int try_import) return obj; } +static inline void +warnings_lock(PyInterpreterState *interp) +{ +#ifdef Py_GIL_DISABLED + WarningsState *st = warnings_get_state(interp); + assert(st != NULL); + _PyRecursiveMutex_Lock(&st->lock); +#endif +} + +static inline void +warnings_unlock(PyInterpreterState *interp) +{ +#ifdef Py_GIL_DISABLED + WarningsState *st = warnings_get_state(interp); + assert(st != NULL); + _PyRecursiveMutex_Unlock(&st->lock); +#endif +} + +static inline bool +warnings_lock_held(WarningsState *st) +{ +#ifdef Py_GIL_DISABLED + return PyMutex_IsLocked(&(st)->lock.mutex); +#else + return true; +#endif +} + +/*[clinic input] +_acquire_lock as warnings_acquire_lock + +[clinic start generated code]*/ + +static PyObject * +warnings_acquire_lock_impl(PyObject *module) +/*[clinic end generated code: output=594313457d1bf8e1 input=46ec20e55acca52f]*/ +{ + PyInterpreterState *interp = get_current_interp(); + if (interp == NULL) { + return NULL; + } + warnings_lock(interp); + Py_RETURN_NONE; +} + +/*[clinic input] +_release_lock as warnings_release_lock + +[clinic start generated code]*/ + +static PyObject * +warnings_release_lock_impl(PyObject *module) +/*[clinic end generated code: output=d73d5a8789396750 input=ea01bb77870c5693]*/ +{ + PyInterpreterState *interp = get_current_interp(); + if (interp == NULL) { + return NULL; + } + warnings_unlock(interp); + Py_RETURN_NONE; +} static PyObject * get_once_registry(PyInterpreterState *interp) @@ -239,7 +302,7 @@ get_once_registry(PyInterpreterState *interp) WarningsState *st = warnings_get_state(interp); assert(st != NULL); - _Py_CRITICAL_SECTION_ASSERT_MUTEX_LOCKED(&st->mutex); + assert(warnings_lock_held(st)); PyObject *registry = GET_WARNINGS_ATTR(interp, onceregistry, 0); if (registry == NULL) { @@ -267,7 +330,7 @@ get_default_action(PyInterpreterState *interp) WarningsState *st = warnings_get_state(interp); assert(st != NULL); - _Py_CRITICAL_SECTION_ASSERT_MUTEX_LOCKED(&st->mutex); + assert(warnings_lock_held(st)); PyObject *default_action = GET_WARNINGS_ATTR(interp, defaultaction, 0); if (default_action == NULL) { @@ -299,7 +362,7 @@ get_filter(PyInterpreterState *interp, PyObject *category, WarningsState *st = warnings_get_state(interp); assert(st != NULL); - _Py_CRITICAL_SECTION_ASSERT_MUTEX_LOCKED(&st->mutex); + assert(warnings_lock_held(st)); PyObject *warnings_filters = GET_WARNINGS_ATTR(interp, filters, 0); if (warnings_filters == NULL) { @@ -399,7 +462,7 @@ already_warned(PyInterpreterState *interp, PyObject *registry, PyObject *key, WarningsState *st = warnings_get_state(interp); assert(st != NULL); - _Py_CRITICAL_SECTION_ASSERT_MUTEX_LOCKED(&st->mutex); + assert(warnings_lock_held(st)); PyObject *version_obj; if (PyDict_GetItemRef(registry, &_Py_ID(version), &version_obj) < 0) { @@ -994,15 +1057,10 @@ do_warn(PyObject *message, PyObject *category, Py_ssize_t stack_level, &filename, &lineno, &module, ®istry)) return NULL; -#ifdef Py_GIL_DISABLED - WarningsState *st = warnings_get_state(tstate->interp); - assert(st != NULL); -#endif - - Py_BEGIN_CRITICAL_SECTION_MUT(&st->mutex); + warnings_lock(tstate->interp); res = warn_explicit(tstate, category, message, filename, lineno, module, registry, NULL, source); - Py_END_CRITICAL_SECTION(); + warnings_unlock(tstate->interp); Py_DECREF(filename); Py_DECREF(registry); Py_DECREF(module); @@ -1151,27 +1209,22 @@ warnings_warn_explicit_impl(PyObject *module, PyObject *message, } } -#ifdef Py_GIL_DISABLED - WarningsState *st = warnings_get_state(tstate->interp); - assert(st != NULL); -#endif - - Py_BEGIN_CRITICAL_SECTION_MUT(&st->mutex); + warnings_lock(tstate->interp); returned = warn_explicit(tstate, category, message, filename, lineno, mod, registry, source_line, sourceobj); - Py_END_CRITICAL_SECTION(); + warnings_unlock(tstate->interp); Py_XDECREF(source_line); return returned; } /*[clinic input] -_filters_mutated as warnings_filters_mutated +_filters_mutated_lock_held as warnings_filters_mutated_lock_held [clinic start generated code]*/ static PyObject * -warnings_filters_mutated_impl(PyObject *module) -/*[clinic end generated code: output=8ce517abd12b88f4 input=35ecbf08ee2491b2]*/ +warnings_filters_mutated_lock_held_impl(PyObject *module) +/*[clinic end generated code: output=df5c84f044e856ec input=34208bf03d70e432]*/ { PyInterpreterState *interp = get_current_interp(); if (interp == NULL) { @@ -1181,14 +1234,17 @@ warnings_filters_mutated_impl(PyObject *module) WarningsState *st = warnings_get_state(interp); assert(st != NULL); - Py_BEGIN_CRITICAL_SECTION_MUT(&st->mutex); + // Note that the lock must be held by the caller. + if (!warnings_lock_held(st)) { + PyErr_SetString(PyExc_RuntimeError, "warnings lock is not held"); + return NULL; + } + st->filters_version++; - Py_END_CRITICAL_SECTION(); Py_RETURN_NONE; } - /* Function to issue a warning message; may raise an exception. */ static int @@ -1303,15 +1359,10 @@ PyErr_WarnExplicitObject(PyObject *category, PyObject *message, return -1; } -#ifdef Py_GIL_DISABLED - WarningsState *st = warnings_get_state(tstate->interp); - assert(st != NULL); -#endif - - Py_BEGIN_CRITICAL_SECTION_MUT(&st->mutex); + warnings_lock(tstate->interp); res = warn_explicit(tstate, category, message, filename, lineno, module, registry, NULL, NULL); - Py_END_CRITICAL_SECTION(); + warnings_unlock(tstate->interp); if (res == NULL) return -1; Py_DECREF(res); @@ -1376,15 +1427,10 @@ PyErr_WarnExplicitFormat(PyObject *category, PyObject *res; PyThreadState *tstate = get_current_tstate(); if (tstate != NULL) { -#ifdef Py_GIL_DISABLED - WarningsState *st = warnings_get_state(tstate->interp); - assert(st != NULL); -#endif - - Py_BEGIN_CRITICAL_SECTION_MUT(&st->mutex); + warnings_lock(tstate->interp); res = warn_explicit(tstate, category, message, filename, lineno, module, registry, NULL, NULL); - Py_END_CRITICAL_SECTION(); + warnings_unlock(tstate->interp); Py_DECREF(message); if (res != NULL) { Py_DECREF(res); @@ -1464,7 +1510,9 @@ _PyErr_WarnUnawaitedCoroutine(PyObject *coro) static PyMethodDef warnings_functions[] = { WARNINGS_WARN_METHODDEF WARNINGS_WARN_EXPLICIT_METHODDEF - WARNINGS_FILTERS_MUTATED_METHODDEF + WARNINGS_FILTERS_MUTATED_LOCK_HELD_METHODDEF + WARNINGS_ACQUIRE_LOCK_METHODDEF + WARNINGS_RELEASE_LOCK_METHODDEF /* XXX(brett.cannon): add showwarning? */ /* XXX(brett.cannon): Reasonable to add formatwarning? */ {NULL, NULL} /* sentinel */ diff --git a/Python/clinic/_warnings.c.h b/Python/clinic/_warnings.c.h index 9a2c33f2ea8169..bcb4b344fa4370 100644 --- a/Python/clinic/_warnings.c.h +++ b/Python/clinic/_warnings.c.h @@ -9,6 +9,40 @@ preserve #include "pycore_abstract.h" // _PyNumber_Index() #include "pycore_modsupport.h" // _PyArg_UnpackKeywords() +PyDoc_STRVAR(warnings_acquire_lock__doc__, +"_acquire_lock($module, /)\n" +"--\n" +"\n"); + +#define WARNINGS_ACQUIRE_LOCK_METHODDEF \ + {"_acquire_lock", (PyCFunction)warnings_acquire_lock, METH_NOARGS, warnings_acquire_lock__doc__}, + +static PyObject * +warnings_acquire_lock_impl(PyObject *module); + +static PyObject * +warnings_acquire_lock(PyObject *module, PyObject *Py_UNUSED(ignored)) +{ + return warnings_acquire_lock_impl(module); +} + +PyDoc_STRVAR(warnings_release_lock__doc__, +"_release_lock($module, /)\n" +"--\n" +"\n"); + +#define WARNINGS_RELEASE_LOCK_METHODDEF \ + {"_release_lock", (PyCFunction)warnings_release_lock, METH_NOARGS, warnings_release_lock__doc__}, + +static PyObject * +warnings_release_lock_impl(PyObject *module); + +static PyObject * +warnings_release_lock(PyObject *module, PyObject *Py_UNUSED(ignored)) +{ + return warnings_release_lock_impl(module); +} + PyDoc_STRVAR(warnings_warn__doc__, "warn($module, /, message, category=None, stacklevel=1, source=None, *,\n" " skip_file_prefixes=)\n" @@ -230,20 +264,20 @@ warnings_warn_explicit(PyObject *module, PyObject *const *args, Py_ssize_t nargs return return_value; } -PyDoc_STRVAR(warnings_filters_mutated__doc__, -"_filters_mutated($module, /)\n" +PyDoc_STRVAR(warnings_filters_mutated_lock_held__doc__, +"_filters_mutated_lock_held($module, /)\n" "--\n" "\n"); -#define WARNINGS_FILTERS_MUTATED_METHODDEF \ - {"_filters_mutated", (PyCFunction)warnings_filters_mutated, METH_NOARGS, warnings_filters_mutated__doc__}, +#define WARNINGS_FILTERS_MUTATED_LOCK_HELD_METHODDEF \ + {"_filters_mutated_lock_held", (PyCFunction)warnings_filters_mutated_lock_held, METH_NOARGS, warnings_filters_mutated_lock_held__doc__}, static PyObject * -warnings_filters_mutated_impl(PyObject *module); +warnings_filters_mutated_lock_held_impl(PyObject *module); static PyObject * -warnings_filters_mutated(PyObject *module, PyObject *Py_UNUSED(ignored)) +warnings_filters_mutated_lock_held(PyObject *module, PyObject *Py_UNUSED(ignored)) { - return warnings_filters_mutated_impl(module); + return warnings_filters_mutated_lock_held_impl(module); } -/*[clinic end generated code: output=ed02c0f521a03a37 input=a9049054013a1b77]*/ +/*[clinic end generated code: output=d9d32a8b59a30683 input=a9049054013a1b77]*/ From fbd62898085530136ff00d376e44ef7e245fe030 Mon Sep 17 00:00:00 2001 From: Neil Schemenauer Date: Fri, 27 Dec 2024 13:31:18 -0800 Subject: [PATCH 4/4] Use contextvar for catch_warnings(). --- .../pycore_global_objects_fini_generated.h | 1 + Include/internal/pycore_global_strings.h | 1 + .../internal/pycore_runtime_init_generated.h | 1 + .../internal/pycore_unicodeobject_generated.h | 4 + Lib/test/test_warnings/__init__.py | 46 ++--- Lib/warnings.py | 122 ++++++++--- Python/_warnings.c | 190 ++++++++++++++---- 7 files changed, 281 insertions(+), 84 deletions(-) diff --git a/Include/internal/pycore_global_objects_fini_generated.h b/Include/internal/pycore_global_objects_fini_generated.h index 90214a314031d1..5172f4f63d2c55 100644 --- a/Include/internal/pycore_global_objects_fini_generated.h +++ b/Include/internal/pycore_global_objects_fini_generated.h @@ -776,6 +776,7 @@ _PyStaticObjects_CheckRefcnt(PyInterpreterState *interp) { _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(_type_)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(_uninitialized_submodules)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(_warn_unawaited_coroutine)); + _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(_warnings_context)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(_xoptions)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(abs_tol)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(access)); diff --git a/Include/internal/pycore_global_strings.h b/Include/internal/pycore_global_strings.h index 97a75d0c46c867..433fe58714ed72 100644 --- a/Include/internal/pycore_global_strings.h +++ b/Include/internal/pycore_global_strings.h @@ -265,6 +265,7 @@ struct _Py_global_strings { STRUCT_FOR_ID(_type_) STRUCT_FOR_ID(_uninitialized_submodules) STRUCT_FOR_ID(_warn_unawaited_coroutine) + STRUCT_FOR_ID(_warnings_context) STRUCT_FOR_ID(_xoptions) STRUCT_FOR_ID(abs_tol) STRUCT_FOR_ID(access) diff --git a/Include/internal/pycore_runtime_init_generated.h b/Include/internal/pycore_runtime_init_generated.h index 4f928cc050bf8e..cf32fe451b73e5 100644 --- a/Include/internal/pycore_runtime_init_generated.h +++ b/Include/internal/pycore_runtime_init_generated.h @@ -774,6 +774,7 @@ extern "C" { INIT_ID(_type_), \ INIT_ID(_uninitialized_submodules), \ INIT_ID(_warn_unawaited_coroutine), \ + INIT_ID(_warnings_context), \ INIT_ID(_xoptions), \ INIT_ID(abs_tol), \ INIT_ID(access), \ diff --git a/Include/internal/pycore_unicodeobject_generated.h b/Include/internal/pycore_unicodeobject_generated.h index 5b78d038fc1192..de8d27f1892b5b 100644 --- a/Include/internal/pycore_unicodeobject_generated.h +++ b/Include/internal/pycore_unicodeobject_generated.h @@ -856,6 +856,10 @@ _PyUnicode_InitStaticStrings(PyInterpreterState *interp) { _PyUnicode_InternStatic(interp, &string); assert(_PyUnicode_CheckConsistency(string, 1)); assert(PyUnicode_GET_LENGTH(string) != 1); + string = &_Py_ID(_warnings_context); + _PyUnicode_InternStatic(interp, &string); + assert(_PyUnicode_CheckConsistency(string, 1)); + assert(PyUnicode_GET_LENGTH(string) != 1); string = &_Py_ID(_xoptions); _PyUnicode_InternStatic(interp, &string); assert(_PyUnicode_CheckConsistency(string, 1)); diff --git a/Lib/test/test_warnings/__init__.py b/Lib/test/test_warnings/__init__.py index 4bd164b8a9a82b..a2e43ea4f3a72b 100644 --- a/Lib/test/test_warnings/__init__.py +++ b/Lib/test/test_warnings/__init__.py @@ -43,15 +43,14 @@ def warnings_state(module): except NameError: pass original_warnings = warning_tests.warnings - original_filters = module.filters + saved_context, context = module._new_context() try: - module.filters = original_filters[:] module.simplefilter("once") warning_tests.warnings = module yield finally: warning_tests.warnings = original_warnings - module.filters = original_filters + module._set_context(saved_context) class TestWarning(Warning): @@ -336,15 +335,15 @@ def test_filterwarnings_duplicate_filters(self): with original_warnings.catch_warnings(module=self.module): self.module.resetwarnings() self.module.filterwarnings("error", category=UserWarning) - self.assertEqual(len(self.module.filters), 1) + self.assertEqual(len(self.module._get_filters()), 1) self.module.filterwarnings("ignore", category=UserWarning) self.module.filterwarnings("error", category=UserWarning) self.assertEqual( - len(self.module.filters), 2, + len(self.module._get_filters()), 2, "filterwarnings inserted duplicate filter" ) self.assertEqual( - self.module.filters[0][0], "error", + self.module._get_filters()[0][0], "error", "filterwarnings did not promote filter to " "the beginning of list" ) @@ -353,15 +352,15 @@ def test_simplefilter_duplicate_filters(self): with original_warnings.catch_warnings(module=self.module): self.module.resetwarnings() self.module.simplefilter("error", category=UserWarning) - self.assertEqual(len(self.module.filters), 1) + self.assertEqual(len(self.module._get_filters()), 1) self.module.simplefilter("ignore", category=UserWarning) self.module.simplefilter("error", category=UserWarning) self.assertEqual( - len(self.module.filters), 2, + len(self.module._get_filters()), 2, "simplefilter inserted duplicate filter" ) self.assertEqual( - self.module.filters[0][0], "error", + self.module._get_filters()[0][0], "error", "simplefilter did not promote filter to the beginning of list" ) @@ -373,7 +372,7 @@ def test_append_duplicate(self): self.module.simplefilter("error", append=True) self.module.simplefilter("ignore", append=True) self.module.warn("test_append_duplicate", category=UserWarning) - self.assertEqual(len(self.module.filters), 2, + self.assertEqual(len(self.module._get_filters()), 2, "simplefilter inserted duplicate filter" ) self.assertEqual(len(w), 0, @@ -1049,11 +1048,11 @@ def test_issue31416(self): # bad warnings.filters or warnings.defaultaction. wmod = self.module with original_warnings.catch_warnings(module=wmod): - wmod.filters = [(None, None, Warning, None, 0)] + wmod._get_filters()[:] = [(None, None, Warning, None, 0)] with self.assertRaises(TypeError): wmod.warn_explicit('foo', Warning, 'bar', 1) - wmod.filters = [] + wmod._get_filters()[:] = [] with support.swap_attr(wmod, 'defaultaction', None), \ self.assertRaises(TypeError): wmod.warn_explicit('foo', Warning, 'bar', 1) @@ -1191,17 +1190,17 @@ class CatchWarningTests(BaseTest): def test_catch_warnings_restore(self): wmod = self.module - orig_filters = wmod.filters + orig_filters = wmod._get_filters() orig_showwarning = wmod.showwarning # Ensure both showwarning and filters are restored when recording with wmod.catch_warnings(module=wmod, record=True): - wmod.filters = wmod.showwarning = object() - self.assertIs(wmod.filters, orig_filters) + wmod.get_context()._filters = wmod.showwarning = object() + self.assertIs(wmod._get_filters(), orig_filters) self.assertIs(wmod.showwarning, orig_showwarning) # Same test, but with recording disabled with wmod.catch_warnings(module=wmod, record=False): - wmod.filters = wmod.showwarning = object() - self.assertIs(wmod.filters, orig_filters) + wmod.get_context()._filters = wmod.showwarning = object() + self.assertIs(wmod._get_filters(), orig_filters) self.assertIs(wmod.showwarning, orig_showwarning) def test_catch_warnings_recording(self): @@ -1240,21 +1239,21 @@ def test_catch_warnings_reentry_guard(self): def test_catch_warnings_defaults(self): wmod = self.module - orig_filters = wmod.filters + orig_filters = wmod._get_filters() orig_showwarning = wmod.showwarning # Ensure default behaviour is not to record warnings with wmod.catch_warnings(module=wmod) as w: self.assertIsNone(w) self.assertIs(wmod.showwarning, orig_showwarning) - self.assertIsNot(wmod.filters, orig_filters) - self.assertIs(wmod.filters, orig_filters) + self.assertIsNot(wmod._get_filters(), orig_filters) + self.assertIs(wmod._get_filters(), orig_filters) if wmod is sys.modules['warnings']: # Ensure the default module is this one with wmod.catch_warnings() as w: self.assertIsNone(w) self.assertIs(wmod.showwarning, orig_showwarning) - self.assertIsNot(wmod.filters, orig_filters) - self.assertIs(wmod.filters, orig_filters) + self.assertIsNot(wmod._get_filters(), orig_filters) + self.assertIs(wmod._get_filters(), orig_filters) def test_record_override_showwarning_before(self): # Issue #28835: If warnings.showwarning() was overridden, make sure @@ -1406,7 +1405,7 @@ def test_default_filter_configuration(self): code = "import sys; sys.modules.pop('warnings', None); sys.modules['_warnings'] = None; " else: code = "" - code += "import warnings; [print(f) for f in warnings.filters]" + code += "import warnings; [print(f) for f in warnings._get_filters()]" rc, stdout, stderr = assert_python_ok("-c", code, __isolated=True) stdout_lines = [line.strip() for line in stdout.splitlines()] @@ -1821,6 +1820,7 @@ async def coro(self): self.assertFalse(inspect.iscoroutinefunction(Cls.sync)) self.assertTrue(inspect.iscoroutinefunction(Cls.coro)) + def setUpModule(): py_warnings.onceregistry.clear() c_warnings.onceregistry.clear() diff --git a/Lib/warnings.py b/Lib/warnings.py index d679974a46c658..7a96919fef9ae6 100644 --- a/Lib/warnings.py +++ b/Lib/warnings.py @@ -1,12 +1,73 @@ """Python part of the warnings subsystem.""" import sys +import itertools as _itertools +import contextvars as _contextvars __all__ = ["warn", "warn_explicit", "showwarning", "formatwarning", "filterwarnings", "simplefilter", "resetwarnings", "catch_warnings", "deprecated"] +class _Context: + def __init__(self, filters): + self._filters = filters + self.log = None # if set to a list, logging is enabled + + def copy(self): + context = _Context(self._filters[:]) + if self.log is not None: + context.log = self.log + return context + + def _record_warning(self, msg): + self.log.append(msg) + + +class _GlobalContext(_Context): + def __init__(self): + self.log = None + + @property + def _filters(self): + # Since there is quite a lot of code that assigns to + # warnings.filters, this needs to return the current value of + # the module global. + try: + return filters + except NameError: + # 'filters' global was deleted. Do we need to actually handle this case? + return [] + +_global_context = _GlobalContext() +_warnings_context = _contextvars.ContextVar('warnings_context') + +def get_context(): + try: + context = _warnings_context.get() + except LookupError: + context = _global_context + _warnings_context.set(context) + return context + + +def _set_context(context): + _warnings_context.set(context) + + +def _new_context(): + old_context = get_context() + new_context = old_context.copy() + _set_context(new_context) + return old_context, new_context + + +def _get_filters(): + """Return the current list of filters. This is a non-public API used by + the unit tests.""" + return get_context()._filters + + def showwarning(message, category, filename, lineno, file=None, line=None): """Hook to write a warning to a file; replace if you like.""" msg = WarningMessage(message, category, filename, lineno, file, line) @@ -18,6 +79,10 @@ def formatwarning(message, category, filename, lineno, line=None): return _formatwarnmsg_impl(msg) def _showwarnmsg_impl(msg): + context = get_context() + if context.log is not None: + context._record_warning(msg) + return file = msg.file if file is None: file = sys.stderr @@ -129,7 +194,7 @@ def _formatwarnmsg(msg): return _formatwarnmsg_impl(msg) def filterwarnings(action, message="", category=Warning, module="", lineno=0, - append=False): + append=False, *, context=None): """Insert an entry into the list of warnings filters (at the front). 'action' -- one of "error", "ignore", "always", "all", "default", "module", @@ -165,9 +230,11 @@ def filterwarnings(action, message="", category=Warning, module="", lineno=0, else: module = None - _add_filter(action, message, category, module, lineno, append=append) + _add_filter(action, message, category, module, lineno, append=append, + context=context) -def simplefilter(action, category=Warning, lineno=0, append=False): +def simplefilter(action, category=Warning, lineno=0, append=False, *, + context=None): """Insert a simple entry into the list of warnings filters (at the front). A simple filter matches all modules and messages. @@ -183,7 +250,8 @@ def simplefilter(action, category=Warning, lineno=0, append=False): raise TypeError("lineno must be an int") if lineno < 0: raise ValueError("lineno must be an int >= 0") - _add_filter(action, None, category, None, lineno, append=append) + _add_filter(action, None, category, None, lineno, append=append, + context=context) def _filters_mutated(): # Even though this function is part of the public API, it's used by @@ -191,8 +259,11 @@ def _filters_mutated(): with _lock: _filters_mutated_lock_held() -def _add_filter(*item, append): +def _add_filter(*item, append, context=None): with _lock: + if context is None: + context = get_context() + filters = context._filters if not append: # Remove possible duplicate filters, so new one will be placed # in correct place. If append=True and duplicate exists, do nothing. @@ -206,10 +277,12 @@ def _add_filter(*item, append): filters.append(item) _filters_mutated_lock_held() -def resetwarnings(): +def resetwarnings(*, context=None): """Clear the list of warning filters, so that no filters are active.""" with _lock: - filters[:] = [] + if context is None: + context = get_context() + del context._filters[:] _filters_mutated_lock_held() class _OptionError(Exception): @@ -378,7 +451,7 @@ def warn_explicit(message, category, filename, lineno, if registry.get(key): return # Search the filters - for item in filters: + for item in get_context()._filters: action, msg, cat, mod, ln = item if ((msg is None or msg.match(text)) and issubclass(category, cat) and @@ -499,31 +572,28 @@ def __enter__(self): raise RuntimeError("Cannot enter %r twice" % self) self._entered = True with _lock: - self._filters = self._module.filters - self._module.filters = self._filters[:] - self._module._filters_mutated_lock_held() + self._saved_context, context = self._module._new_context() self._showwarning = self._module.showwarning self._showwarnmsg_impl = self._module._showwarnmsg_impl + if self._record: + context.log = log = [] + # Reset showwarning() to the default implementation to make sure + # that _showwarnmsg() calls _showwarnmsg_impl() + self._module.showwarning = self._module._showwarning_orig + else: + log = None if self._filter is not None: - simplefilter(*self._filter) - if self._record: - log = [] - self._module._showwarnmsg_impl = log.append - # Reset showwarning() to the default implementation to make sure - # that _showwarnmsg() calls _showwarnmsg_impl() - self._module.showwarning = self._module._showwarning_orig - return log - else: - return None + self._module.simplefilter(*self._filter, context=context) + return log def __exit__(self, *exc_info): if not self._entered: raise RuntimeError("Cannot exit %r without entering first" % self) with _lock: - self._module.filters = self._filters - self._module._filters_mutated_lock_held() + self._module._warnings_context.set(self._saved_context) self._module.showwarning = self._showwarning self._module._showwarnmsg_impl = self._showwarnmsg_impl + self._module._filters_mutated_lock_held() class deprecated: @@ -762,3 +832,9 @@ def _filters_mutated_lock_held(): simplefilter("ignore", category=ResourceWarning, append=1) del _warnings_defaults + +#def __getattr__(name): +# if name == "filters": +# warn('Accessing warnings.filters is likely not thread-safe.', DeprecationWarning, stacklevel=2) +# return get_context()._filters +# raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/Python/_warnings.c b/Python/_warnings.c index 5302be64403614..b86a50c15e1258 100644 --- a/Python/_warnings.c +++ b/Python/_warnings.c @@ -262,6 +262,87 @@ warnings_lock_held(WarningsState *st) #endif } +static PyObject * +get_warnings_context(PyInterpreterState *interp) +{ + PyObject *ctx_var = GET_WARNINGS_ATTR(interp, _warnings_context, 0); + if (ctx_var == NULL) { + if (!PyErr_Occurred()) { + // likely that the 'warnings' module doesn't exist anymore + Py_RETURN_NONE; + } + else { + return NULL; + } + } + if (!PyContextVar_CheckExact(ctx_var)) { + PyErr_Format(PyExc_TypeError, + MODULE_NAME "._warnings_defaults must be a ContextVar, " + "not '%.200s'", + Py_TYPE(ctx_var)->tp_name); + Py_DECREF(ctx_var); + return NULL; + } + PyObject *ctx; + if (PyContextVar_Get(ctx_var, NULL, &ctx) < 0) { + Py_DECREF(ctx_var); + return NULL; + } + Py_DECREF(ctx_var); + if (ctx == NULL) { + Py_RETURN_NONE; + } + return ctx; +} + +static PyObject * +get_warnings_context_filters(PyInterpreterState *interp) +{ + PyObject *ctx = get_warnings_context(interp); + if (ctx == NULL) { + return NULL; + } + if (ctx == Py_None) { + Py_DECREF(ctx); + Py_RETURN_NONE; + } + PyObject *context_filters = PyObject_GetAttrString(ctx, "_filters"); + Py_DECREF(ctx); + if (context_filters == NULL) { + return NULL; + } + if (!PyList_Check(context_filters)) { + PyErr_SetString(PyExc_ValueError, + "warnings._warnings_context _filters must be a list"); + Py_DECREF(context_filters); + return NULL; + } + return context_filters; +} + +// Returns a borrowed reference to the list. +static PyObject * +get_warnings_filters(PyInterpreterState *interp) +{ + WarningsState *st = warnings_get_state(interp); + PyObject *warnings_filters = GET_WARNINGS_ATTR(interp, filters, 0); + if (warnings_filters == NULL) { + if (PyErr_Occurred()) + return NULL; + } + else { + Py_SETREF(st->filters, warnings_filters); + } + + PyObject *filters = st->filters; + if (filters == NULL || !PyList_Check(filters)) { + PyErr_SetString(PyExc_ValueError, + MODULE_NAME ".filters must be a list"); + return NULL; + } + return filters; +} + /*[clinic input] _acquire_lock as warnings_acquire_lock @@ -352,35 +433,14 @@ get_default_action(PyInterpreterState *interp) return default_action; } - -/* The item is a new reference. */ -static PyObject* -get_filter(PyInterpreterState *interp, PyObject *category, - PyObject *text, Py_ssize_t lineno, - PyObject *module, PyObject **item) -{ - WarningsState *st = warnings_get_state(interp); - assert(st != NULL); - - assert(warnings_lock_held(st)); - - PyObject *warnings_filters = GET_WARNINGS_ATTR(interp, filters, 0); - if (warnings_filters == NULL) { - if (PyErr_Occurred()) - return NULL; - } - else { - Py_SETREF(st->filters, warnings_filters); - } - - PyObject *filters = st->filters; - if (filters == NULL || !PyList_Check(filters)) { - PyErr_SetString(PyExc_ValueError, - MODULE_NAME ".filters must be a list"); - return NULL; - } - - /* WarningsState.filters could change while we are iterating over it. */ +/* Search filters list of match, returns false on error. If no match + * then 'matched_action' is NULL. */ +static bool +filter_search(PyInterpreterState *interp, PyObject *category, + PyObject *text, Py_ssize_t lineno, + PyObject *module, char *list_name, PyObject *filters, + PyObject **item, PyObject **matched_action) { + /* filters list could change while we are iterating over it. */ for (Py_ssize_t i = 0; i < PyList_GET_SIZE(filters); i++) { PyObject *tmp_item, *action, *msg, *cat, *mod, *ln_obj; Py_ssize_t ln; @@ -389,8 +449,8 @@ get_filter(PyInterpreterState *interp, PyObject *category, tmp_item = PyList_GET_ITEM(filters, i); if (!PyTuple_Check(tmp_item) || PyTuple_GET_SIZE(tmp_item) != 5) { PyErr_Format(PyExc_ValueError, - MODULE_NAME ".filters item %zd isn't a 5-tuple", i); - return NULL; + "warnings.%s item %zd isn't a 5-tuple", list_name, i); + return false; } /* Python code: action, msg, cat, mod, ln = item */ @@ -406,42 +466,96 @@ get_filter(PyInterpreterState *interp, PyObject *category, "action must be a string, not '%.200s'", Py_TYPE(action)->tp_name); Py_DECREF(tmp_item); - return NULL; + return false; } good_msg = check_matched(interp, msg, text); if (good_msg == -1) { Py_DECREF(tmp_item); - return NULL; + return false; } good_mod = check_matched(interp, mod, module); if (good_mod == -1) { Py_DECREF(tmp_item); - return NULL; + return false; } is_subclass = PyObject_IsSubclass(category, cat); if (is_subclass == -1) { Py_DECREF(tmp_item); - return NULL; + return false; } ln = PyLong_AsSsize_t(ln_obj); if (ln == -1 && PyErr_Occurred()) { Py_DECREF(tmp_item); - return NULL; + return false; } if (good_msg && is_subclass && good_mod && (ln == 0 || lineno == ln)) { *item = tmp_item; - return action; + *matched_action = action; + return true; } Py_DECREF(tmp_item); } + *matched_action = NULL; + return true; +} + +/* The item is a new reference. */ +static PyObject* +get_filter(PyInterpreterState *interp, PyObject *category, + PyObject *text, Py_ssize_t lineno, + PyObject *module, PyObject **item) +{ + WarningsState *st = warnings_get_state(interp); + assert(st != NULL); + + assert(warnings_lock_held(st)); + + /* check _warning_context _filters list */ + PyObject *context_filters = get_warnings_context_filters(interp); + bool use_global_filters = false; + if (context_filters == NULL) { + return NULL; + } + if (context_filters == Py_None) { + use_global_filters = true; + Py_DECREF(context_filters); + } else { + PyObject *context_action = NULL; + if (!filter_search(interp, category, text, lineno, module, "_warnings_context _filters", + context_filters, item, &context_action)) { + Py_DECREF(context_filters); + return NULL; + } + Py_DECREF(context_filters); + if (context_action != NULL) { + return context_action; + } + } + + PyObject *action; + + if (use_global_filters) { + /* check warnings.filters list */ + PyObject *filters = get_warnings_filters(interp); + if (filters == NULL) { + return NULL; + } + if (!filter_search(interp, category, text, lineno, module, "filters", + filters, item, &action)) { + return NULL; + } + if (action != NULL) { + return action; + } + } - PyObject *action = get_default_action(interp); + action = get_default_action(interp); if (action != NULL) { *item = Py_NewRef(Py_None); return action;