Skip to content

Commit

Permalink
pythongh-124872: Mark the thread's default context as entered
Browse files Browse the repository at this point in the history
Starting with commit 843d28f
(temporarily reverted in d3c82b9 and
restored in commit bee112a), it is
now technically possible to access a thread's default context created
by `context_get`.  Mark that context as entered so that users cannot
push that context onto the thread's stack a second time, which would
cause a cycle.

Also exit that context when the thread exits, for symmetry and in case
the user wants to re-enter it for some reason.

(Even if the `CONTEXT_SWITCHED` event is removed, entering the default
context is good defensive practice, and the consistent treatment of
all contexts on the stack makes it easier to understand the code.)
  • Loading branch information
rhansen committed Oct 17, 2024
1 parent 8e7b2a1 commit 1726539
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 10 deletions.
13 changes: 12 additions & 1 deletion Include/internal/pycore_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ extern PyTypeObject _PyContextTokenMissing_Type;

PyStatus _PyContext_Init(PyInterpreterState *);

// Exits any thread-owned contexts (see context_get) at the top of the thread's
// context stack. Logs a warning via PyErr_FormatUnraisable if the thread's
// context stack is non-empty afterwards (those contexts can never be exited or
// re-entered). The given thread state must belong to the thread calling this
// function.
void _PyContext_ExitThreadOwned(PyThreadState *);


/* other API */

Expand All @@ -27,7 +34,11 @@ struct _pycontextobject {
PyContext *ctx_prev;
PyHamtObject *ctx_vars;
PyObject *ctx_weakreflist;
int ctx_entered;
_Bool ctx_entered:1;
// True for the thread's default context created by context_get. Used to
// safely determine whether the base context can be exited when clearing a
// PyThreadState.
_Bool ctx_owned_by_thread:1;
};


Expand Down
71 changes: 71 additions & 0 deletions Lib/test/test_capi/test_watchers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import sys
import threading
import unittest
import contextvars

Expand Down Expand Up @@ -659,5 +661,74 @@ def test_exit_base_context(self):
ctx.run(lambda: None)
self.assertEqual(switches, [ctx, None])

def test_reenter_default_context(self):
_testcapi.clear_context_stack()
# contextvars.copy_context() creates the thread's default context (via
# the context_get C function).
ctx = contextvars.copy_context()
with self.context_watcher(0) as switches:
ctx.run(lambda: None)
self.assertEqual(len(switches), 2)
self.assertEqual(switches[0], ctx)
base_ctx = switches[1]
self.assertIsNotNone(base_ctx)
self.assertIsNot(base_ctx, ctx)
with self.assertRaisesRegex(RuntimeError, 'already entered'):
base_ctx.run(lambda: None)

def test_default_context_enter(self):
_testcapi.clear_context_stack()
with self.context_watcher(0) as switches:
ctx = contextvars.copy_context()
ctx.run(lambda: None)
self.assertEqual(len(switches), 3)
base_ctx = switches[0]
self.assertIsNotNone(base_ctx)
self.assertEqual(switches, [base_ctx, ctx, base_ctx])

def test_default_context_exit_during_thread_cleanup(self):
# Context watchers are per-interpreter, not per-thread.
# https://discuss.python.org/t/v3-14a1-design-limitations-of-pycontext-addwatcher/68177
with self.context_watcher(0) as switches:
def _thread_main():
_testcapi.clear_context_stack()
# contextvars.copy_context() creates the thread's default
# context (via the context_get C function).
contextvars.copy_context()
# This test only cares about the final switch that happens when
# exiting the thread's default context during thread cleanup.
switches.clear()

thread = threading.Thread(target=_thread_main)
thread.start()
thread.join()
self.assertEqual(switches, [None])

def test_thread_cleanup_with_entered_context(self):
unraisables = []
try:
with catch_unraisable_exception() as cm:
with self.context_watcher(0) as switches:
def _thread_main():
_testcapi.clear_context_stack()
ctx = contextvars.copy_context()
_testcapi.context_enter(ctx)
switches.clear()

thread = threading.Thread(target=_thread_main)
thread.start()
thread.join()
unraisables.append(cm.unraisable)
self.assertEqual(switches, [])
self.assertEqual(len(unraisables), 1)
self.assertIsNotNone(unraisables[0])
self.assertRegex(unraisables[0].err_msg,
r'^Exception ignored during reset of thread state')
self.assertRegex(str(unraisables[0].exc_value), r'still entered')
finally:
# Break reference cycle
unraisables = None


if __name__ == "__main__":
unittest.main()
10 changes: 10 additions & 0 deletions Modules/_testcapi/watchers.c
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,15 @@ clear_context_stack(PyObject *Py_UNUSED(self), PyObject *Py_UNUSED(args))
Py_RETURN_NONE;
}

static PyObject *
context_enter(PyObject *self, PyObject *ctx)
{
if (PyContext_Enter(ctx)) {
return NULL;
}
Py_RETURN_NONE;
}

static PyObject *
get_context_switches(PyObject *Py_UNUSED(self), PyObject *watcher_id)
{
Expand Down Expand Up @@ -841,6 +850,7 @@ static PyMethodDef test_methods[] = {
{"add_context_watcher", add_context_watcher, METH_O, NULL},
{"clear_context_watcher", clear_context_watcher, METH_O, NULL},
{"clear_context_stack", clear_context_stack, METH_NOARGS, NULL},
{"context_enter", context_enter, METH_O, NULL},
{"get_context_switches", get_context_switches, METH_O, NULL},
{"allocate_too_many_context_watchers",
(PyCFunction) allocate_too_many_context_watchers, METH_NOARGS, NULL},
Expand Down
53 changes: 44 additions & 9 deletions Python/context.c
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,9 @@ static inline void
context_switched(PyThreadState *ts)
{
ts->context_ver++;
// ts->context is used instead of context_get() because context_get() might
// throw if ts->context is NULL.
// ts->context is used instead of context_get() because if ts->context is
// NULL, context_get() will either call context_switched -- causing a
// double notification -- or throw.
notify_context_watchers(ts, Py_CONTEXT_SWITCHED, ts->context);
}

Expand Down Expand Up @@ -244,6 +245,7 @@ _PyContext_Exit(PyThreadState *ts, PyObject *octx)

ctx->ctx_prev = NULL;
ctx->ctx_entered = 0;
ctx->ctx_owned_by_thread = 0;
context_switched(ts);
return 0;
}
Expand All @@ -257,6 +259,32 @@ PyContext_Exit(PyObject *octx)
}


void
_PyContext_ExitThreadOwned(PyThreadState *ts)
{
assert(ts != NULL);
// ts must belong to the current thread because PyErr_FormatUnraisable
// operates on the current thread (it calls sys.unraisablehook).
assert(ts == _PyThreadState_GET());
while (ts->context != NULL
&& PyContext_CheckExact(ts->context)
&& ((PyContext *)ts->context)->ctx_owned_by_thread) {
if (_PyContext_Exit(ts, ts->context)) {
Py_UNREACHABLE();
}
}
if (ts->context != NULL) {
PyObject *exc = _PyErr_GetRaisedException(ts);
_PyErr_SetString(ts, PyExc_RuntimeError,
"contextvars.Context object(s) still entered during "
"thread state reset");
PyErr_FormatUnraisable(
"Exception ignored during reset of thread state %p", ts);
_PyErr_SetRaisedException(ts, exc);
}
}


PyObject *
PyContextVar_New(const char *name, PyObject *def)
{
Expand Down Expand Up @@ -433,6 +461,7 @@ _context_alloc(void)
ctx->ctx_vars = NULL;
ctx->ctx_prev = NULL;
ctx->ctx_entered = 0;
ctx->ctx_owned_by_thread = 0;
ctx->ctx_weakreflist = NULL;

return ctx;
Expand Down Expand Up @@ -478,15 +507,21 @@ context_get(void)
{
PyThreadState *ts = _PyThreadState_GET();
assert(ts != NULL);
PyContext *current_ctx = (PyContext *)ts->context;
if (current_ctx == NULL) {
current_ctx = context_new_empty();
if (current_ctx == NULL) {
return NULL;
if (ts->context == NULL) {
PyContext *ctx = context_new_empty();
if (ctx != NULL) {
if (_PyContext_Enter(ts, (PyObject *)ctx)) {
Py_UNREACHABLE();
}
ctx->ctx_owned_by_thread = 1;
}
ts->context = (PyObject *)current_ctx;
assert(ts->context == (PyObject *)ctx);
Py_CLEAR(ctx); // _PyContext_Enter created its own ref.
}
return current_ctx;
// The current context may be NULL if the above context_new_empty() call
// failed.
assert(ts->context == NULL || PyContext_CheckExact(ts->context));
return (PyContext *)ts->context;
}

static int
Expand Down
4 changes: 4 additions & 0 deletions Python/pystate.c
Original file line number Diff line number Diff line change
Expand Up @@ -1650,6 +1650,10 @@ PyThreadState_Clear(PyThreadState *tstate)
"PyThreadState_Clear: warning: thread still has a frame\n");
}

// This calls callbacks registered with PyContext_AddWatcher and can call
// sys.unraisablehook.
_PyContext_ExitThreadOwned(tstate);

/* At this point tstate shouldn't be used any more,
neither to run Python code nor for other uses.
Expand Down

0 comments on commit 1726539

Please sign in to comment.