diff --git a/Include/internal/pycore_crossinterp.h b/Include/internal/pycore_crossinterp.h index 19c55dd65983d7..3a04d9f44847d2 100644 --- a/Include/internal/pycore_crossinterp.h +++ b/Include/internal/pycore_crossinterp.h @@ -131,7 +131,23 @@ PyAPI_FUNC(void) _PyXIData_Clear(PyInterpreterState *, _PyXIData_t *); /* getting cross-interpreter data */ -typedef int (*xidatafunc)(PyThreadState *tstate, PyObject *, _PyXIData_t *); +typedef int xidata_fallback_t; +#define _PyXIDATA_XIDATA_ONLY (0) +#define _PyXIDATA_FULL_FALLBACK (1) + +// Technically, we don't need two different function types; +// we could go with just the fallback one. However, only container +// types like tuple need it, so always having the extra arg would be +// a bit unfortunate. It's also nice to be able to clearly distinguish +// between types that might call _PyObject_GetXIData() and those that won't. +// +typedef int (*xidatafunc)(PyThreadState *, PyObject *, _PyXIData_t *); +typedef int (*xidatafbfunc)( + PyThreadState *, PyObject *, xidata_fallback_t, _PyXIData_t *); +typedef struct { + xidatafunc basic; + xidatafbfunc fallback; +} _PyXIData_getdata_t; PyAPI_FUNC(PyObject *) _PyXIData_GetNotShareableErrorType(PyThreadState *); PyAPI_FUNC(void) _PyXIData_SetNotShareableError(PyThreadState *, const char *); @@ -140,7 +156,7 @@ PyAPI_FUNC(void) _PyXIData_FormatNotShareableError( const char *, ...); -PyAPI_FUNC(xidatafunc) _PyXIData_Lookup( +PyAPI_FUNC(_PyXIData_getdata_t) _PyXIData_Lookup( PyThreadState *, PyObject *); PyAPI_FUNC(int) _PyObject_CheckXIData( @@ -151,6 +167,11 @@ PyAPI_FUNC(int) _PyObject_GetXIData( PyThreadState *, PyObject *, _PyXIData_t *); +PyAPI_FUNC(int) _PyObject_GetXIDataWithFallback( + PyThreadState *, + PyObject *, + xidata_fallback_t, + _PyXIData_t *); // _PyObject_GetXIData() for bytes typedef struct { @@ -314,24 +335,9 @@ typedef struct _sharedexception { PyAPI_FUNC(PyObject *) _PyXI_ApplyError(_PyXI_error *err); -typedef struct xi_session _PyXI_session; -typedef struct _sharedns _PyXI_namespace; - -PyAPI_FUNC(void) _PyXI_FreeNamespace(_PyXI_namespace *ns); -PyAPI_FUNC(_PyXI_namespace *) _PyXI_NamespaceFromNames(PyObject *names); -PyAPI_FUNC(int) _PyXI_FillNamespaceFromDict( - _PyXI_namespace *ns, - PyObject *nsobj, - _PyXI_session *session); -PyAPI_FUNC(int) _PyXI_ApplyNamespace( - _PyXI_namespace *ns, - PyObject *nsobj, - PyObject *dflt); - - // A cross-interpreter session involves entering an interpreter -// (_PyXI_Enter()), doing some work with it, and finally exiting -// that interpreter (_PyXI_Exit()). +// with _PyXI_Enter(), doing some work with it, and finally exiting +// that interpreter with _PyXI_Exit(). // // At the boundaries of the session, both entering and exiting, // data may be exchanged between the previous interpreter and the @@ -339,39 +345,10 @@ PyAPI_FUNC(int) _PyXI_ApplyNamespace( // isolation between interpreters. This includes setting objects // in the target's __main__ module on the way in, and capturing // uncaught exceptions on the way out. -struct xi_session { - // Once a session has been entered, this is the tstate that was - // current before the session. If it is different from cur_tstate - // then we must have switched interpreters. Either way, this will - // be the current tstate once we exit the session. - PyThreadState *prev_tstate; - // Once a session has been entered, this is the current tstate. - // It must be current when the session exits. - PyThreadState *init_tstate; - // This is true if init_tstate needs cleanup during exit. - int own_init_tstate; - - // This is true if, while entering the session, init_thread took - // "ownership" of the interpreter's __main__ module. This means - // it is the only thread that is allowed to run code there. - // (Caveat: for now, users may still run exec() against the - // __main__ module's dict, though that isn't advisable.) - int running; - // This is a cached reference to the __dict__ of the entered - // interpreter's __main__ module. It is looked up when at the - // beginning of the session as a convenience. - PyObject *main_ns; - - // This is set if the interpreter is entered and raised an exception - // that needs to be handled in some special way during exit. - _PyXI_errcode *error_override; - // This is set if exit captured an exception to propagate. - _PyXI_error *error; - - // -- pre-allocated memory -- - _PyXI_error _error; - _PyXI_errcode _error_override; -}; +typedef struct xi_session _PyXI_session; + +PyAPI_FUNC(_PyXI_session *) _PyXI_NewSession(void); +PyAPI_FUNC(void) _PyXI_FreeSession(_PyXI_session *); PyAPI_FUNC(int) _PyXI_Enter( _PyXI_session *session, @@ -379,6 +356,8 @@ PyAPI_FUNC(int) _PyXI_Enter( PyObject *nsupdates); PyAPI_FUNC(void) _PyXI_Exit(_PyXI_session *session); +PyAPI_FUNC(PyObject *) _PyXI_GetMainNamespace(_PyXI_session *); + PyAPI_FUNC(PyObject *) _PyXI_ApplyCapturedException(_PyXI_session *session); PyAPI_FUNC(int) _PyXI_HasCapturedException(_PyXI_session *session); diff --git a/Include/internal/pycore_crossinterp_data_registry.h b/Include/internal/pycore_crossinterp_data_registry.h index 8f4bcb948e5a45..fbb4cad5cac32e 100644 --- a/Include/internal/pycore_crossinterp_data_registry.h +++ b/Include/internal/pycore_crossinterp_data_registry.h @@ -17,7 +17,7 @@ typedef struct _xid_regitem { /* This is NULL for builtin types. */ PyObject *weakref; size_t refcount; - xidatafunc getdata; + _PyXIData_getdata_t getdata; } _PyXIData_regitem_t; typedef struct { @@ -30,7 +30,7 @@ typedef struct { PyAPI_FUNC(int) _PyXIData_RegisterClass( PyThreadState *, PyTypeObject *, - xidatafunc); + _PyXIData_getdata_t); PyAPI_FUNC(int) _PyXIData_UnregisterClass( PyThreadState *, PyTypeObject *); diff --git a/Include/internal/pycore_pyerrors.h b/Include/internal/pycore_pyerrors.h index f357b88e220e6e..2c2048f7e1272a 100644 --- a/Include/internal/pycore_pyerrors.h +++ b/Include/internal/pycore_pyerrors.h @@ -94,13 +94,13 @@ extern void _PyErr_Fetch( PyObject **value, PyObject **traceback); -extern PyObject* _PyErr_GetRaisedException(PyThreadState *tstate); +PyAPI_FUNC(PyObject*) _PyErr_GetRaisedException(PyThreadState *tstate); PyAPI_FUNC(int) _PyErr_ExceptionMatches( PyThreadState *tstate, PyObject *exc); -extern void _PyErr_SetRaisedException(PyThreadState *tstate, PyObject *exc); +PyAPI_FUNC(void) _PyErr_SetRaisedException(PyThreadState *tstate, PyObject *exc); extern void _PyErr_Restore( PyThreadState *tstate, diff --git a/Lib/concurrent/futures/interpreter.py b/Lib/concurrent/futures/interpreter.py index d17688dc9d7346..a2c4fbfd3fb831 100644 --- a/Lib/concurrent/futures/interpreter.py +++ b/Lib/concurrent/futures/interpreter.py @@ -36,9 +36,6 @@ def __str__(self): """.strip()) -UNBOUND = 2 # error; this should not happen. - - class WorkerContext(_thread.WorkerContext): @classmethod @@ -47,23 +44,13 @@ def resolve_task(fn, args, kwargs): if isinstance(fn, str): # XXX Circle back to this later. raise TypeError('scripts not supported') - if args or kwargs: - raise ValueError(f'a script does not take args or kwargs, got {args!r} and {kwargs!r}') - data = textwrap.dedent(fn) - kind = 'script' - # Make sure the script compiles. - # Ideally we wouldn't throw away the resulting code - # object. However, there isn't much to be done until - # code objects are shareable and/or we do a better job - # of supporting code objects in _interpreters.exec(). - compile(data, '', 'exec') else: # Functions defined in the __main__ module can't be pickled, # so they can't be used here. In the future, we could possibly # borrow from multiprocessing to work around this. - data = pickle.dumps((fn, args, kwargs)) - kind = 'function' - return (data, kind) + task = (fn, args, kwargs) + data = pickle.dumps(task) + return data if initializer is not None: try: @@ -86,24 +73,20 @@ def _capture_exc(cls, resultsid): except BaseException as exc: # Send the captured exception out on the results queue, # but still leave it unhandled for the interpreter to handle. - err = pickle.dumps(exc) - _interpqueues.put(resultsid, (None, err), 1, UNBOUND) + _interpqueues.put(resultsid, (None, exc)) raise # re-raise @classmethod def _send_script_result(cls, resultsid): - _interpqueues.put(resultsid, (None, None), 0, UNBOUND) + _interpqueues.put(resultsid, (None, None)) @classmethod def _call(cls, func, args, kwargs, resultsid): with cls._capture_exc(resultsid): res = func(*args or (), **kwargs or {}) # Send the result back. - try: - _interpqueues.put(resultsid, (res, None), 0, UNBOUND) - except _interpreters.NotShareableError: - res = pickle.dumps(res) - _interpqueues.put(resultsid, (res, None), 1, UNBOUND) + with cls._capture_exc(resultsid): + _interpqueues.put(resultsid, (res, None)) @classmethod def _call_pickled(cls, pickled, resultsid): @@ -134,8 +117,7 @@ def initialize(self): _interpreters.incref(self.interpid) maxsize = 0 - fmt = 0 - self.resultsid = _interpqueues.create(maxsize, fmt, UNBOUND) + self.resultsid = _interpqueues.create(maxsize) self._exec(f'from {__name__} import WorkerContext') @@ -166,17 +148,8 @@ def finalize(self): pass def run(self, task): - data, kind = task - if kind == 'script': - raise NotImplementedError('script kind disabled') - script = f""" -with WorkerContext._capture_exc({self.resultsid}): -{textwrap.indent(data, ' ')} -WorkerContext._send_script_result({self.resultsid})""" - elif kind == 'function': - script = f'WorkerContext._call_pickled({data!r}, {self.resultsid})' - else: - raise NotImplementedError(kind) + data = task + script = f'WorkerContext._call_pickled({data!r}, {self.resultsid})' try: self._exec(script) @@ -199,15 +172,13 @@ def run(self, task): continue else: break - (res, excdata), pickled, unboundop = obj + (res, exc), unboundop = obj assert unboundop is None, unboundop - if excdata is not None: + if exc is not None: assert res is None, res - assert pickled assert exc_wrapper is not None - exc = pickle.loads(excdata) raise exc from exc_wrapper - return pickle.loads(res) if pickled else res + return res class BrokenInterpreterPool(_thread.BrokenThreadPool): diff --git a/Lib/test/support/interpreters/channels.py b/Lib/test/support/interpreters/channels.py index d2bd93d77f7169..3b6e0f0effd969 100644 --- a/Lib/test/support/interpreters/channels.py +++ b/Lib/test/support/interpreters/channels.py @@ -55,15 +55,23 @@ def create(*, unbounditems=UNBOUND): """ unbound = _serialize_unbound(unbounditems) unboundop, = unbound - cid = _channels.create(unboundop) - recv, send = RecvChannel(cid), SendChannel(cid, _unbound=unbound) + cid = _channels.create(unboundop, -1) + recv, send = RecvChannel(cid), SendChannel(cid) + send._set_unbound(unboundop, unbounditems) return recv, send def list_all(): """Return a list of (recv, send) for all open channels.""" - return [(RecvChannel(cid), SendChannel(cid, _unbound=unbound)) - for cid, unbound in _channels.list_all()] + channels = [] + for cid, unboundop, _ in _channels.list_all(): + chan = _, send = RecvChannel(cid), SendChannel(cid) + if not hasattr(send, '_unboundop'): + send._set_unbound(unboundop) + else: + assert send._unbound[0] == unboundop + channels.append(chan) + return channels class _ChannelEnd: @@ -175,16 +183,33 @@ class SendChannel(_ChannelEnd): _end = 'send' - def __new__(cls, cid, *, _unbound=None): - if _unbound is None: - try: - op = _channels.get_channel_defaults(cid) - _unbound = (op,) - except ChannelNotFoundError: - _unbound = _serialize_unbound(UNBOUND) - self = super().__new__(cls, cid) - self._unbound = _unbound - return self +# def __new__(cls, cid, *, _unbound=None): +# if _unbound is None: +# try: +# op = _channels.get_channel_defaults(cid) +# _unbound = (op,) +# except ChannelNotFoundError: +# _unbound = _serialize_unbound(UNBOUND) +# self = super().__new__(cls, cid) +# self._unbound = _unbound +# return self + + def _set_unbound(self, op, items=None): + assert not hasattr(self, '_unbound') + if items is None: + items = _resolve_unbound(op) + unbound = (op, items) + self._unbound = unbound + return unbound + + @property + def unbounditems(self): + try: + _, items = self._unbound + except AttributeError: + op, _ = _channels.get_queue_defaults(self._id) + _, items = self._set_unbound(op) + return items @property def is_closed(self): @@ -192,61 +217,61 @@ def is_closed(self): return info.closed or info.closing def send(self, obj, timeout=None, *, - unbound=None, + unbounditems=None, ): """Send the object (i.e. its data) to the channel's receiving end. This blocks until the object is received. """ - if unbound is None: - unboundop, = self._unbound + if unbounditems is None: + unboundop = -1 else: - unboundop, = _serialize_unbound(unbound) + unboundop, = _serialize_unbound(unbounditems) _channels.send(self._id, obj, unboundop, timeout=timeout, blocking=True) def send_nowait(self, obj, *, - unbound=None, + unbounditems=None, ): """Send the object to the channel's receiving end. If the object is immediately received then return True (else False). Otherwise this is the same as send(). """ - if unbound is None: - unboundop, = self._unbound + if unbounditems is None: + unboundop = -1 else: - unboundop, = _serialize_unbound(unbound) + unboundop, = _serialize_unbound(unbounditems) # XXX Note that at the moment channel_send() only ever returns # None. This should be fixed when channel_send_wait() is added. # See bpo-32604 and gh-19829. return _channels.send(self._id, obj, unboundop, blocking=False) def send_buffer(self, obj, timeout=None, *, - unbound=None, + unbounditems=None, ): """Send the object's buffer to the channel's receiving end. This blocks until the object is received. """ - if unbound is None: - unboundop, = self._unbound + if unbounditems is None: + unboundop = -1 else: - unboundop, = _serialize_unbound(unbound) + unboundop, = _serialize_unbound(unbounditems) _channels.send_buffer(self._id, obj, unboundop, timeout=timeout, blocking=True) def send_buffer_nowait(self, obj, *, - unbound=None, + unbounditems=None, ): """Send the object's buffer to the channel's receiving end. If the object is immediately received then return True (else False). Otherwise this is the same as send(). """ - if unbound is None: - unboundop, = self._unbound + if unbounditems is None: + unboundop = -1 else: - unboundop, = _serialize_unbound(unbound) + unboundop, = _serialize_unbound(unbounditems) return _channels.send_buffer(self._id, obj, unboundop, blocking=False) def close(self): diff --git a/Lib/test/support/interpreters/queues.py b/Lib/test/support/interpreters/queues.py index deb8e8613af731..d6a3197d9e0e26 100644 --- a/Lib/test/support/interpreters/queues.py +++ b/Lib/test/support/interpreters/queues.py @@ -63,29 +63,34 @@ def _resolve_unbound(flag): return resolved -def create(maxsize=0, *, syncobj=False, unbounditems=UNBOUND): +def create(maxsize=0, *, unbounditems=UNBOUND): """Return a new cross-interpreter queue. The queue may be used to pass data safely between interpreters. - "syncobj" sets the default for Queue.put() - and Queue.put_nowait(). - - "unbounditems" likewise sets the default. See Queue.put() for + "unbounditems" sets the default for Queue.put(); see that method for supported values. The default value is UNBOUND, which replaces the unbound item. """ - fmt = _SHARED_ONLY if syncobj else _PICKLED unbound = _serialize_unbound(unbounditems) unboundop, = unbound - qid = _queues.create(maxsize, fmt, unboundop) - return Queue(qid, _fmt=fmt, _unbound=unbound) + qid = _queues.create(maxsize, unboundop, -1) + self = Queue(qid) + self._set_unbound(unboundop, unbounditems) + return self def list_all(): """Return a list of all open queues.""" - return [Queue(qid, _fmt=fmt, _unbound=(unboundop,)) - for qid, fmt, unboundop in _queues.list_all()] + queues = [] + for qid, unboundop, _ in _queues.list_all(): + self = Queue(qid) + if not hasattr(self, '_unbound'): + self._set_unbound(unboundop) + else: + assert self._unbound[0] == unboundop + queues.append(self) + return queues _known_queues = weakref.WeakValueDictionary() @@ -93,28 +98,17 @@ def list_all(): class Queue: """A cross-interpreter queue.""" - def __new__(cls, id, /, *, _fmt=None, _unbound=None): + def __new__(cls, id, /): # There is only one instance for any given ID. if isinstance(id, int): id = int(id) else: raise TypeError(f'id must be an int, got {id!r}') - if _fmt is None: - if _unbound is None: - _fmt, op = _queues.get_queue_defaults(id) - _unbound = (op,) - else: - _fmt, _ = _queues.get_queue_defaults(id) - elif _unbound is None: - _, op = _queues.get_queue_defaults(id) - _unbound = (op,) try: self = _known_queues[id] except KeyError: self = super().__new__(cls) self._id = id - self._fmt = _fmt - self._unbound = _unbound _known_queues[id] = self _queues.bind(id) return self @@ -143,10 +137,27 @@ def __getnewargs__(self): def __getstate__(self): return None + def _set_unbound(self, op, items=None): + assert not hasattr(self, '_unbound') + if items is None: + items = _resolve_unbound(op) + unbound = (op, items) + self._unbound = unbound + return unbound + @property def id(self): return self._id + @property + def unbounditems(self): + try: + _, items = self._unbound + except AttributeError: + op, _ = _queues.get_queue_defaults(self._id) + _, items = self._set_unbound(op) + return items + @property def maxsize(self): try: @@ -165,77 +176,56 @@ def qsize(self): return _queues.get_count(self._id) def put(self, obj, timeout=None, *, - syncobj=None, - unbound=None, + unbounditems=None, _delay=10 / 1000, # 10 milliseconds ): """Add the object to the queue. This blocks while the queue is full. - If "syncobj" is None (the default) then it uses the - queue's default, set with create_queue(). - - If "syncobj" is false then all objects are supported, - at the expense of worse performance. - - If "syncobj" is true then the object must be "shareable". - Examples of "shareable" objects include the builtin singletons, - str, and memoryview. One benefit is that such objects are - passed through the queue efficiently. - - The key difference, though, is conceptual: the corresponding - object returned from Queue.get() will be strictly equivalent - to the given obj. In other words, the two objects will be - effectively indistinguishable from each other, even if the - object is mutable. The received object may actually be the - same object, or a copy (immutable values only), or a proxy. - Regardless, the received object should be treated as though - the original has been shared directly, whether or not it - actually is. That's a slightly different and stronger promise - than just (initial) equality, which is all "syncobj=False" - can promise. - - "unbound" controls the behavior of Queue.get() for the given + For most objects, the object received through Queue.get() will + be a new one, equivalent to the original and not sharing any + actual underlying data. The notable exceptions include + cross-interpreter types (like Queue) and memoryview, where the + underlying data is actually shared. Furthermore, some types + can be sent through a queue more efficiently than others. This + group includes various immutable types like int, str, bytes, and + tuple (if the items are likewise efficiently shareable). See interpreters.is_shareable(). + + "unbounditems" controls the behavior of Queue.get() for the given object if the current interpreter (calling put()) is later destroyed. - If "unbound" is None (the default) then it uses the + If "unbounditems" is None (the default) then it uses the queue's default, set with create_queue(), which is usually UNBOUND. - If "unbound" is UNBOUND_ERROR then get() will raise an + If "unbounditems" is UNBOUND_ERROR then get() will raise an ItemInterpreterDestroyed exception if the original interpreter has been destroyed. This does not otherwise affect the queue; the next call to put() will work like normal, returning the next item in the queue. - If "unbound" is UNBOUND_REMOVE then the item will be removed + If "unbounditems" is UNBOUND_REMOVE then the item will be removed from the queue as soon as the original interpreter is destroyed. Be aware that this will introduce an imbalance between put() and get() calls. - If "unbound" is UNBOUND then it is returned by get() in place + If "unbounditems" is UNBOUND then it is returned by get() in place of the unbound item. """ - if syncobj is None: - fmt = self._fmt - else: - fmt = _SHARED_ONLY if syncobj else _PICKLED - if unbound is None: - unboundop, = self._unbound + if unbounditems is None: + unboundop = -1 else: - unboundop, = _serialize_unbound(unbound) + unboundop, = _serialize_unbound(unbounditems) if timeout is not None: timeout = int(timeout) if timeout < 0: raise ValueError(f'timeout value must be non-negative') end = time.time() + timeout - if fmt is _PICKLED: - obj = pickle.dumps(obj) while True: try: - _queues.put(self._id, obj, fmt, unboundop) + _queues.put(self._id, obj, unboundop) except QueueFull as exc: if timeout is not None and time.time() >= end: raise # re-raise @@ -243,18 +233,12 @@ def put(self, obj, timeout=None, *, else: break - def put_nowait(self, obj, *, syncobj=None, unbound=None): - if syncobj is None: - fmt = self._fmt + def put_nowait(self, obj, *, unbounditems=None): + if unbounditems is None: + unboundop = -1 else: - fmt = _SHARED_ONLY if syncobj else _PICKLED - if unbound is None: - unboundop, = self._unbound - else: - unboundop, = _serialize_unbound(unbound) - if fmt is _PICKLED: - obj = pickle.dumps(obj) - _queues.put(self._id, obj, fmt, unboundop) + unboundop, = _serialize_unbound(unbounditems) + _queues.put(self._id, obj, unboundop) def get(self, timeout=None, *, _delay=10 / 1000, # 10 milliseconds @@ -265,7 +249,7 @@ def get(self, timeout=None, *, If the next item's original interpreter has been destroyed then the "next object" is determined by the value of the - "unbound" argument to put(). + "unbounditems" argument to put(). """ if timeout is not None: timeout = int(timeout) @@ -274,7 +258,7 @@ def get(self, timeout=None, *, end = time.time() + timeout while True: try: - obj, fmt, unboundop = _queues.get(self._id) + obj, unboundop = _queues.get(self._id) except QueueEmpty as exc: if timeout is not None and time.time() >= end: raise # re-raise @@ -284,10 +268,6 @@ def get(self, timeout=None, *, if unboundop is not None: assert obj is None, repr(obj) return _resolve_unbound(unboundop) - if fmt == _PICKLED: - obj = pickle.loads(obj) - else: - assert fmt == _SHARED_ONLY return obj def get_nowait(self): @@ -297,16 +277,12 @@ def get_nowait(self): is the same as get(). """ try: - obj, fmt, unboundop = _queues.get(self._id) + obj, unboundop = _queues.get(self._id) except QueueEmpty as exc: raise # re-raise if unboundop is not None: assert obj is None, repr(obj) return _resolve_unbound(unboundop) - if fmt == _PICKLED: - obj = pickle.loads(obj) - else: - assert fmt == _SHARED_ONLY return obj diff --git a/Lib/test/test__interpchannels.py b/Lib/test/test__interpchannels.py index e4c1ad854514ed..88eee03a3de93a 100644 --- a/Lib/test/test__interpchannels.py +++ b/Lib/test/test__interpchannels.py @@ -247,7 +247,7 @@ def _run_action(cid, action, end, state): def clean_up_channels(): - for cid, _ in _channels.list_all(): + for cid, _, _ in _channels.list_all(): try: _channels.destroy(cid) except _channels.ChannelNotFoundError: @@ -373,11 +373,11 @@ def test_create_cid(self): self.assertIsInstance(cid, _channels.ChannelID) def test_sequential_ids(self): - before = [cid for cid, _ in _channels.list_all()] + before = [cid for cid, _, _ in _channels.list_all()] id1 = _channels.create(REPLACE) id2 = _channels.create(REPLACE) id3 = _channels.create(REPLACE) - after = [cid for cid, _ in _channels.list_all()] + after = [cid for cid, _, _ in _channels.list_all()] self.assertEqual(id2, int(id1) + 1) self.assertEqual(id3, int(id2) + 1) diff --git a/Lib/test/test__interpreters.py b/Lib/test/test__interpreters.py index 0c43f46300f67d..4553aeb52ac1ab 100644 --- a/Lib/test/test__interpreters.py +++ b/Lib/test/test__interpreters.py @@ -474,13 +474,15 @@ def setUp(self): def test_signatures(self): # See https://github.com/python/cpython/issues/126654 - msg = "expected 'shared' to be a dict" + msg = r'_interpreters.exec\(\) argument 3 must be dict, not int' with self.assertRaisesRegex(TypeError, msg): _interpreters.exec(self.id, 'a', 1) with self.assertRaisesRegex(TypeError, msg): _interpreters.exec(self.id, 'a', shared=1) + msg = r'_interpreters.run_string\(\) argument 3 must be dict, not int' with self.assertRaisesRegex(TypeError, msg): _interpreters.run_string(self.id, 'a', shared=1) + msg = r'_interpreters.run_func\(\) argument 3 must be dict, not int' with self.assertRaisesRegex(TypeError, msg): _interpreters.run_func(self.id, lambda: None, shared=1) @@ -952,7 +954,8 @@ def test_invalid_syntax(self): """) with self.subTest('script'): - self.assert_run_failed(SyntaxError, script) + with self.assertRaises(SyntaxError): + _interpreters.run_string(self.id, script) with self.subTest('module'): modname = 'spam_spam_spam' @@ -1019,12 +1022,19 @@ def script(): with open(w, 'w', encoding="utf-8") as spipe: with contextlib.redirect_stdout(spipe): print('it worked!', end='') + failed = None def f(): - _interpreters.set___main___attrs(self.id, dict(w=w)) - _interpreters.run_func(self.id, script) + nonlocal failed + try: + _interpreters.set___main___attrs(self.id, dict(w=w)) + _interpreters.run_func(self.id, script) + except Exception as exc: + failed = exc t = threading.Thread(target=f) t.start() t.join() + if failed: + raise Exception from failed with open(r, encoding="utf-8") as outfile: out = outfile.read() @@ -1053,12 +1063,9 @@ def test_closure(self): spam = True def script(): assert spam - with self.assertRaises(ValueError): _interpreters.run_func(self.id, script) - # XXX This hasn't been fixed yet. - @unittest.expectedFailure def test_return_value(self): def script(): return 'spam' diff --git a/Lib/test/test_crossinterp.py b/Lib/test/test_crossinterp.py index cddacbc9970052..ceeef0a7659fb4 100644 --- a/Lib/test/test_crossinterp.py +++ b/Lib/test/test_crossinterp.py @@ -5,6 +5,7 @@ import sys import types import unittest +import warnings from test.support import import_helper @@ -16,13 +17,281 @@ from test import _crossinterp_definitions as defs -BUILTIN_TYPES = [o for _, o in __builtins__.items() - if isinstance(o, type)] -EXCEPTION_TYPES = [cls for cls in BUILTIN_TYPES +@contextlib.contextmanager +def ignore_byteswarning(): + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=BytesWarning) + yield + + +# builtin types + +BUILTINS_TYPES = [o for _, o in __builtins__.items() if isinstance(o, type)] +EXCEPTION_TYPES = [cls for cls in BUILTINS_TYPES if issubclass(cls, BaseException)] OTHER_TYPES = [o for n, o in vars(types).items() if (isinstance(o, type) and - n not in ('DynamicClassAttribute', '_GeneratorWrapper'))] + n not in ('DynamicClassAttribute', '_GeneratorWrapper'))] +BUILTIN_TYPES = [ + *BUILTINS_TYPES, + *OTHER_TYPES, +] + +# builtin exceptions + +try: + raise Exception +except Exception as exc: + CAUGHT = exc +EXCEPTIONS_WITH_SPECIAL_SIG = { + BaseExceptionGroup: (lambda msg: (msg, [CAUGHT])), + ExceptionGroup: (lambda msg: (msg, [CAUGHT])), + UnicodeError: (lambda msg: (None, msg, None, None, None)), + UnicodeEncodeError: (lambda msg: ('utf-8', '', 1, 3, msg)), + UnicodeDecodeError: (lambda msg: ('utf-8', b'', 1, 3, msg)), + UnicodeTranslateError: (lambda msg: ('', 1, 3, msg)), +} +BUILTIN_EXCEPTIONS = [ + *(cls(*sig('error!')) for cls, sig in EXCEPTIONS_WITH_SPECIAL_SIG.items()), + *(cls('error!') for cls in EXCEPTION_TYPES + if cls not in EXCEPTIONS_WITH_SPECIAL_SIG), +] + +# other builtin objects + +METHOD = defs.SpamOkay().okay +BUILTIN_METHOD = [].append +METHOD_DESCRIPTOR_WRAPPER = str.join +METHOD_WRAPPER = object().__str__ +WRAPPER_DESCRIPTOR = object.__init__ +BUILTIN_WRAPPERS = { + METHOD: types.MethodType, + BUILTIN_METHOD: types.BuiltinMethodType, + dict.__dict__['fromkeys']: types.ClassMethodDescriptorType, + types.FunctionType.__code__: types.GetSetDescriptorType, + types.FunctionType.__globals__: types.MemberDescriptorType, + METHOD_DESCRIPTOR_WRAPPER: types.MethodDescriptorType, + METHOD_WRAPPER: types.MethodWrapperType, + WRAPPER_DESCRIPTOR: types.WrapperDescriptorType, + staticmethod(defs.SpamOkay.okay): None, + classmethod(defs.SpamOkay.okay): None, + property(defs.SpamOkay.okay): None, +} +BUILTIN_FUNCTIONS = [ + # types.BuiltinFunctionType + len, + sys.is_finalizing, + sys.exit, + _testinternalcapi.get_crossinterp_data, +] +assert 'emptymod' not in sys.modules +with import_helper.ready_to_import('emptymod', ''): + import emptymod as EMPTYMOD +MODULES = [ + sys, + defs, + unittest, + EMPTYMOD, +] +OBJECT = object() +EXCEPTION = Exception() +LAMBDA = (lambda: None) +BUILTIN_SIMPLE = [ + OBJECT, + # singletons + None, + True, + False, + Ellipsis, + NotImplemented, + # bytes + *(i.to_bytes(2, 'little', signed=True) + for i in range(-1, 258)), + # str + 'hello world', + '你好世界', + '', + # int + sys.maxsize + 1, + sys.maxsize, + -sys.maxsize - 1, + -sys.maxsize - 2, + *range(-1, 258), + 2**1000, + # float + 0.0, + 1.1, + -1.0, + 0.12345678, + -0.12345678, +] +TUPLE_EXCEPTION = (0, 1.0, EXCEPTION) +TUPLE_OBJECT = (0, 1.0, OBJECT) +TUPLE_NESTED_EXCEPTION = (0, 1.0, (EXCEPTION,)) +TUPLE_NESTED_OBJECT = (0, 1.0, (OBJECT,)) +MEMORYVIEW_EMPTY = memoryview(b'') +MEMORYVIEW_NOT_EMPTY = memoryview(b'spam'*42) +MAPPING_PROXY_EMPTY = types.MappingProxyType({}) +BUILTIN_CONTAINERS = [ + # tuple (flat) + (), + (1,), + ("hello", "world", ), + (1, True, "hello"), + TUPLE_EXCEPTION, + TUPLE_OBJECT, + # tuple (nested) + ((1,),), + ((1, 2), (3, 4)), + ((1, 2), (3, 4), (5, 6)), + TUPLE_NESTED_EXCEPTION, + TUPLE_NESTED_OBJECT, + # buffer + MEMORYVIEW_EMPTY, + MEMORYVIEW_NOT_EMPTY, + # list + [], + [1, 2, 3], + [[1], (2,), {3: 4}], + # dict + {}, + {1: 7, 2: 8, 3: 9}, + {1: [1], 2: (2,), 3: {3: 4}}, + # set + set(), + {1, 2, 3}, + {frozenset({1}), (2,)}, + # frozenset + frozenset([]), + frozenset({frozenset({1}), (2,)}), + # bytearray + bytearray(b''), + # other + MAPPING_PROXY_EMPTY, + types.SimpleNamespace(), +] +ns = {} +exec(""" +try: + raise Exception +except Exception as exc: + TRACEBACK = exc.__traceback__ + FRAME = TRACEBACK.tb_frame +""", ns, ns) +BUILTIN_OTHER = [ + # types.CellType + types.CellType(), + # types.FrameType + ns['FRAME'], + # types.TracebackType + ns['TRACEBACK'], +] +del ns + +# user-defined objects + +USER_TOP_INSTANCES = [c(*a) for c, a in defs.TOP_CLASSES.items()] +USER_NESTED_INSTANCES = [c(*a) for c, a in defs.NESTED_CLASSES.items()] +USER_INSTANCES = [ + *USER_TOP_INSTANCES, + *USER_NESTED_INSTANCES, +] +USER_EXCEPTIONS = [ + defs.MimimalError('error!'), +] + +# shareable objects + +TUPLES_WITHOUT_EQUALITY = [ + TUPLE_EXCEPTION, + TUPLE_OBJECT, + TUPLE_NESTED_EXCEPTION, + TUPLE_NESTED_OBJECT, +] +_UNSHAREABLE_SIMPLE = [ + Ellipsis, + NotImplemented, + OBJECT, + sys.maxsize + 1, + -sys.maxsize - 2, + 2**1000, +] +with ignore_byteswarning(): + _SHAREABLE_SIMPLE = [o for o in BUILTIN_SIMPLE + if o not in _UNSHAREABLE_SIMPLE] + _SHAREABLE_CONTAINERS = [ + *(o for o in BUILTIN_CONTAINERS if type(o) is memoryview), + *(o for o in BUILTIN_CONTAINERS + if type(o) is tuple and o not in TUPLES_WITHOUT_EQUALITY), + ] + _UNSHAREABLE_CONTAINERS = [o for o in BUILTIN_CONTAINERS + if o not in _SHAREABLE_CONTAINERS] +SHAREABLE = [ + *_SHAREABLE_SIMPLE, + *_SHAREABLE_CONTAINERS, +] +NOT_SHAREABLE = [ + *_UNSHAREABLE_SIMPLE, + *_UNSHAREABLE_CONTAINERS, + *BUILTIN_TYPES, + *BUILTIN_WRAPPERS, + *BUILTIN_EXCEPTIONS, + *BUILTIN_FUNCTIONS, + *MODULES, + *BUILTIN_OTHER, + # types.CodeType + *(f.__code__ for f in defs.FUNCTIONS), + *(f.__code__ for f in defs.FUNCTION_LIKE), + # types.FunctionType + *defs.FUNCTIONS, + defs.SpamOkay.okay, + LAMBDA, + *defs.FUNCTION_LIKE, + # coroutines and generators + *defs.FUNCTION_LIKE_APPLIED, + # user classes + *defs.CLASSES, + *USER_INSTANCES, + # user exceptions + *USER_EXCEPTIONS, +] + +# pickleable objects + +PICKLEABLE = [ + *BUILTIN_SIMPLE, + *(o for o in BUILTIN_CONTAINERS if o not in [ + MEMORYVIEW_EMPTY, + MEMORYVIEW_NOT_EMPTY, + MAPPING_PROXY_EMPTY, + ] or type(o) is dict), + *BUILTINS_TYPES, + *BUILTIN_EXCEPTIONS, + *BUILTIN_FUNCTIONS, + *defs.TOP_FUNCTIONS, + defs.SpamOkay.okay, + *defs.FUNCTION_LIKE, + *defs.TOP_CLASSES, + *USER_TOP_INSTANCES, + *USER_EXCEPTIONS, + # from OTHER_TYPES + types.NoneType, + types.EllipsisType, + types.NotImplementedType, + types.GenericAlias, + types.UnionType, + types.SimpleNamespace, + # from BUILTIN_WRAPPERS + METHOD, + BUILTIN_METHOD, + METHOD_DESCRIPTOR_WRAPPER, + METHOD_WRAPPER, + WRAPPER_DESCRIPTOR, +] +assert not any(isinstance(o, types.MappingProxyType) for o in PICKLEABLE) + + +# helpers DEFS = defs with open(code_defs.__file__) as infile: @@ -111,6 +380,78 @@ class _GetXIDataTests(unittest.TestCase): MODE = None + def assert_functions_equal(self, func1, func2): + assert type(func1) is types.FunctionType, repr(func1) + assert type(func2) is types.FunctionType, repr(func2) + self.assertEqual(func1.__name__, func2.__name__) + self.assertEqual(func1.__code__, func2.__code__) + self.assertEqual(func1.__defaults__, func2.__defaults__) + self.assertEqual(func1.__kwdefaults__, func2.__kwdefaults__) + # We don't worry about __globals__ for now. + + def assert_exc_args_equal(self, exc1, exc2): + args1 = exc1.args + args2 = exc2.args + if isinstance(exc1, ExceptionGroup): + self.assertIs(type(args1), type(args2)) + self.assertEqual(len(args1), 2) + self.assertEqual(len(args1), len(args2)) + self.assertEqual(args1[0], args2[0]) + group1 = args1[1] + group2 = args2[1] + self.assertEqual(len(group1), len(group2)) + for grouped1, grouped2 in zip(group1, group2): + # Currently the "extra" attrs are not preserved + # (via __reduce__). + self.assertIs(type(exc1), type(exc2)) + #self.assert_exc_equal(grouped1, grouped2) + else: + self.assertEqual(args1, args2) + + def assert_exc_equal(self, exc1, exc2): + self.assertIs(type(exc1), type(exc2)) + + if type(exc1).__eq__ is not object.__eq__: + self.assertEqual(exc1, exc2) + + self.assert_exc_args_equal(exc1, exc2) + # XXX For now we do not preserve tracebacks. + if exc1.__traceback__ is not None: + self.assertEqual(exc1.__traceback__, exc2.__traceback__) + self.assertEqual( + getattr(exc1, '__notes__', None), + getattr(exc2, '__notes__', None), + ) + # We assume there are no cycles. + if exc1.__cause__ is None: + self.assertIs(exc1.__cause__, exc2.__cause__) + else: + self.assert_exc_equal(exc1.__cause__, exc2.__cause__) + if exc1.__context__ is None: + self.assertIs(exc1.__context__, exc2.__context__) + else: + self.assert_exc_equal(exc1.__context__, exc2.__context__) + + def assert_equal_or_equalish(self, obj, expected): + cls = type(expected) + if cls.__eq__ is not object.__eq__: +# assert cls not in (types.MethodType, types.BuiltinMethodType, types.MethodWrapperType), cls + self.assertEqual(obj, expected) + elif cls is types.FunctionType: + self.assert_functions_equal(obj, expected) + elif isinstance(expected, BaseException): + self.assert_exc_equal(obj, expected) + elif cls is types.MethodType: + raise NotImplementedError(cls) + elif cls is types.BuiltinMethodType: + raise NotImplementedError(cls) + elif cls is types.MethodWrapperType: + raise NotImplementedError(cls) + elif cls.__bases__ == (object,): + self.assertEqual(obj.__dict__, expected.__dict__) + else: + raise NotImplementedError(cls) + def get_xidata(self, obj, *, mode=None): mode = self._resolve_mode(mode) return _testinternalcapi.get_crossinterp_data(obj, mode) @@ -126,35 +467,37 @@ def _get_roundtrip(self, obj, mode): def assert_roundtrip_identical(self, values, *, mode=None): mode = self._resolve_mode(mode) for obj in values: - with self.subTest(obj): + with self.subTest(repr(obj)): got = self._get_roundtrip(obj, mode) self.assertIs(got, obj) def assert_roundtrip_equal(self, values, *, mode=None, expecttype=None): mode = self._resolve_mode(mode) for obj in values: - with self.subTest(obj): + with self.subTest(repr(obj)): got = self._get_roundtrip(obj, mode) - self.assertEqual(got, obj) + if got is obj: + continue self.assertIs(type(got), type(obj) if expecttype is None else expecttype) + self.assert_equal_or_equalish(got, obj) def assert_roundtrip_equal_not_identical(self, values, *, mode=None, expecttype=None): mode = self._resolve_mode(mode) for obj in values: - with self.subTest(obj): + with self.subTest(repr(obj)): got = self._get_roundtrip(obj, mode) self.assertIsNot(got, obj) self.assertIs(type(got), type(obj) if expecttype is None else expecttype) - self.assertEqual(got, obj) + self.assert_equal_or_equalish(got, obj) def assert_roundtrip_not_equal(self, values, *, mode=None, expecttype=None): mode = self._resolve_mode(mode) for obj in values: - with self.subTest(obj): + with self.subTest(repr(obj)): got = self._get_roundtrip(obj, mode) self.assertIsNot(got, obj) self.assertIs(type(got), @@ -164,7 +507,7 @@ def assert_roundtrip_not_equal(self, values, *, def assert_not_shareable(self, values, exctype=None, *, mode=None): mode = self._resolve_mode(mode) for obj in values: - with self.subTest(obj): + with self.subTest(repr(obj)): with self.assertRaises(NotShareableError) as cm: _testinternalcapi.get_crossinterp_data(obj, mode) if exctype is not None: @@ -182,49 +525,26 @@ class PickleTests(_GetXIDataTests): MODE = 'pickle' def test_shareable(self): - self.assert_roundtrip_equal([ - # singletons - None, - True, - False, - # bytes - *(i.to_bytes(2, 'little', signed=True) - for i in range(-1, 258)), - # str - 'hello world', - '你好世界', - '', - # int - sys.maxsize, - -sys.maxsize - 1, - *range(-1, 258), - # float - 0.0, - 1.1, - -1.0, - 0.12345678, - -0.12345678, - # tuple - (), - (1,), - ("hello", "world", ), - (1, True, "hello"), - ((1,),), - ((1, 2), (3, 4)), - ((1, 2), (3, 4), (5, 6)), - ]) - # not shareable using xidata - self.assert_roundtrip_equal([ - # int - sys.maxsize + 1, - -sys.maxsize - 2, - 2**1000, - # tuple - (0, 1.0, []), - (0, 1.0, {}), - (0, 1.0, ([],)), - (0, 1.0, ({},)), - ]) + with ignore_byteswarning(): + for obj in SHAREABLE: + if obj in PICKLEABLE: + self.assert_roundtrip_equal([obj]) + else: + self.assert_not_shareable([obj]) + + def test_not_shareable(self): + with ignore_byteswarning(): + for obj in NOT_SHAREABLE: + if type(obj) is types.MappingProxyType: + self.assert_not_shareable([obj]) + elif obj in PICKLEABLE: + with self.subTest(repr(obj)): + # We don't worry about checking the actual value. + # The other tests should cover that well enough. + got = self.get_roundtrip(obj) + self.assertIs(type(got), type(obj)) + else: + self.assert_not_shareable([obj]) def test_list(self): self.assert_roundtrip_equal_not_identical([ @@ -266,7 +586,7 @@ def assert_class_defs_same(self, defs): if cls not in defs.CLASSES_WITHOUT_EQUALITY: continue instances.append(cls(*args)) - self.assert_roundtrip_not_equal(instances) + self.assert_roundtrip_equal(instances) def assert_class_defs_other_pickle(self, defs, mod): # Pickle relative to a different module than the original. @@ -286,7 +606,7 @@ def assert_class_defs_other_unpickle(self, defs, mod, *, fail=False): instances = [] for cls, args in defs.TOP_CLASSES.items(): - with self.subTest(cls): + with self.subTest(repr(cls)): setattr(mod, cls.__name__, cls) xid = self.get_xidata(cls) inst = cls(*args) @@ -295,7 +615,7 @@ def assert_class_defs_other_unpickle(self, defs, mod, *, fail=False): (cls, xid, inst, instxid)) for cls, xid, inst, instxid in instances: - with self.subTest(cls): + with self.subTest(repr(cls)): delattr(mod, cls.__name__) if fail: with self.assertRaises(NotShareableError): @@ -505,7 +825,7 @@ def test_nested_function(self): # exceptions def test_user_exception_normal(self): - self.assert_roundtrip_not_equal([ + self.assert_roundtrip_equal([ defs.MimimalError('error!'), ]) self.assert_roundtrip_equal_not_identical([ @@ -531,7 +851,7 @@ def test_builtin_exception(self): args = special.get(cls) or (msg,) exceptions.append(cls(*args)) - self.assert_roundtrip_not_equal(exceptions) + self.assert_roundtrip_equal(exceptions) class MarshalTests(_GetXIDataTests): @@ -576,7 +896,7 @@ def test_simple_builtin_objects(self): '', ]) self.assert_not_shareable([ - object(), + OBJECT, types.SimpleNamespace(), ]) @@ -647,10 +967,7 @@ def test_builtin_type(self): shareable = [ StopIteration, ] - types = [ - *BUILTIN_TYPES, - *OTHER_TYPES, - ] + types = BUILTIN_TYPES self.assert_not_shareable(cls for cls in types if cls not in shareable) self.assert_roundtrip_identical(cls for cls in types @@ -763,7 +1080,7 @@ class ShareableFuncTests(_GetXIDataTests): MODE = 'func' def test_stateless(self): - self.assert_roundtrip_not_equal([ + self.assert_roundtrip_equal([ *defs.STATELESS_FUNCTIONS, # Generators can be stateless too. *defs.FUNCTION_LIKE, @@ -912,10 +1229,49 @@ def test_impure_script_function(self): ], expecttype=types.CodeType) +class ShareableFallbackTests(_GetXIDataTests): + + MODE = 'fallback' + + def test_shareable(self): + self.assert_roundtrip_equal(SHAREABLE) + + def test_not_shareable(self): + okay = [ + *PICKLEABLE, + *defs.STATELESS_FUNCTIONS, + LAMBDA, + ] + ignored = [ + *TUPLES_WITHOUT_EQUALITY, + OBJECT, + METHOD, + BUILTIN_METHOD, + METHOD_WRAPPER, + ] + with ignore_byteswarning(): + self.assert_roundtrip_equal([ + *(o for o in NOT_SHAREABLE + if o in okay and o not in ignored + and o is not MAPPING_PROXY_EMPTY), + ]) + self.assert_roundtrip_not_equal([ + *(o for o in NOT_SHAREABLE + if o in ignored and o is not MAPPING_PROXY_EMPTY), + ]) + self.assert_not_shareable([ + *(o for o in NOT_SHAREABLE if o not in okay), + MAPPING_PROXY_EMPTY, + ]) + + class ShareableTypeTests(_GetXIDataTests): MODE = 'xidata' + def test_shareable(self): + self.assert_roundtrip_equal(SHAREABLE) + def test_singletons(self): self.assert_roundtrip_identical([ None, @@ -983,8 +1339,8 @@ def test_tuple(self): def test_tuples_containing_non_shareable_types(self): non_shareables = [ - Exception(), - object(), + EXCEPTION, + OBJECT, ] for s in non_shareables: value = tuple([0, 1.0, s]) @@ -999,6 +1355,9 @@ def test_tuples_containing_non_shareable_types(self): # The rest are not shareable. + def test_not_shareable(self): + self.assert_not_shareable(NOT_SHAREABLE) + def test_object(self): self.assert_not_shareable([ object(), @@ -1015,12 +1374,12 @@ def test_function_object(self): for func in defs.FUNCTIONS: assert type(func) is types.FunctionType, func assert type(defs.SpamOkay.okay) is types.FunctionType, func - assert type(lambda: None) is types.LambdaType + assert type(LAMBDA) is types.LambdaType self.assert_not_shareable([ *defs.FUNCTIONS, defs.SpamOkay.okay, - (lambda: None), + LAMBDA, ]) def test_builtin_function(self): @@ -1085,10 +1444,7 @@ def test_class(self): self.assert_not_shareable(instances) def test_builtin_type(self): - self.assert_not_shareable([ - *BUILTIN_TYPES, - *OTHER_TYPES, - ]) + self.assert_not_shareable(BUILTIN_TYPES) def test_exception(self): self.assert_not_shareable([ @@ -1127,7 +1483,7 @@ def test_builtin_objects(self): """, ns, ns) self.assert_not_shareable([ - types.MappingProxyType({}), + MAPPING_PROXY_EMPTY, types.SimpleNamespace(), # types.CellType types.CellType(), diff --git a/Lib/test/test_interpreters/test_api.py b/Lib/test/test_interpreters/test_api.py index 66c7afce88f0d8..2ca5cac6083975 100644 --- a/Lib/test/test_interpreters/test_api.py +++ b/Lib/test/test_interpreters/test_api.py @@ -839,9 +839,16 @@ def test_bad_script(self): interp.exec(10) def test_bytes_for_script(self): + r, w = self.pipe() + RAN = b'R' + DONE = b'D' interp = interpreters.create() - with self.assertRaises(TypeError): - interp.exec(b'print("spam")') + interp.exec(f"""if True: + import os + os.write({w}, {RAN!r}) + """) + os.write(w, DONE) + self.assertEqual(os.read(r, 1), RAN) def test_with_background_threads_still_running(self): r_interp, w_interp = self.pipe() @@ -1010,8 +1017,8 @@ def test_call(self): for i, (callable, args, kwargs) in enumerate([ (call_func_noop, (), {}), - (call_func_return_shareable, (), {}), - (call_func_return_not_shareable, (), {}), + #(call_func_return_shareable, (), {}), + #(call_func_return_not_shareable, (), {}), (Spam.noop, (), {}), ]): with self.subTest(f'success case #{i+1}'): @@ -1036,6 +1043,8 @@ def test_call(self): (call_func_complex, ('custom', 'spam!'), {}), (call_func_complex, ('custom-inner', 'eggs!'), {}), (call_func_complex, ('???',), {'exc': ValueError('spam')}), + (call_func_return_shareable, (), {}), + (call_func_return_not_shareable, (), {}), ]): with self.subTest(f'invalid case #{i+1}'): with self.assertRaises(Exception): @@ -1051,8 +1060,8 @@ def test_call_in_thread(self): for i, (callable, args, kwargs) in enumerate([ (call_func_noop, (), {}), - (call_func_return_shareable, (), {}), - (call_func_return_not_shareable, (), {}), + #(call_func_return_shareable, (), {}), + #(call_func_return_not_shareable, (), {}), (Spam.noop, (), {}), ]): with self.subTest(f'success case #{i+1}'): @@ -1079,6 +1088,8 @@ def test_call_in_thread(self): (call_func_complex, ('custom', 'spam!'), {}), (call_func_complex, ('custom-inner', 'eggs!'), {}), (call_func_complex, ('???',), {'exc': ValueError('spam')}), + (call_func_return_shareable, (), {}), + (call_func_return_not_shareable, (), {}), ]): with self.subTest(f'invalid case #{i+1}'): if args or kwargs: @@ -1610,8 +1621,10 @@ def test_exec(self): def test_call(self): with self.subTest('no args'): interpid = _interpreters.create() - exc = _interpreters.call(interpid, call_func_return_shareable) - self.assertIs(exc, None) + with self.assertRaises(ValueError): + _interpreters.call(interpid, call_func_return_shareable) +# exc = _interpreters.call(interpid, call_func_return_shareable) +# self.assertIs(exc, None) with self.subTest('uncaught exception'): interpid = _interpreters.create() diff --git a/Lib/test/test_interpreters/test_channels.py b/Lib/test/test_interpreters/test_channels.py index eada18f99d04db..0c027b17cea68c 100644 --- a/Lib/test/test_interpreters/test_channels.py +++ b/Lib/test/test_interpreters/test_channels.py @@ -377,11 +377,11 @@ def common(rch, sch, unbound=None, presize=0): if not unbound: extraargs = '' elif unbound is channels.UNBOUND: - extraargs = ', unbound=channels.UNBOUND' + extraargs = ', unbounditems=channels.UNBOUND' elif unbound is channels.UNBOUND_ERROR: - extraargs = ', unbound=channels.UNBOUND_ERROR' + extraargs = ', unbounditems=channels.UNBOUND_ERROR' elif unbound is channels.UNBOUND_REMOVE: - extraargs = ', unbound=channels.UNBOUND_REMOVE' + extraargs = ', unbounditems=channels.UNBOUND_REMOVE' else: raise NotImplementedError(repr(unbound)) interp = interpreters.create() @@ -454,11 +454,11 @@ def common(rch, sch, unbound=None, presize=0): with self.assertRaises(channels.ChannelEmptyError): rch.recv_nowait() - sch.send_nowait(b'ham', unbound=channels.UNBOUND_REMOVE) + sch.send_nowait(b'ham', unbounditems=channels.UNBOUND_REMOVE) self.assertEqual(_channels.get_count(rch.id), 1) interp = common(rch, sch, channels.UNBOUND_REMOVE, 1) self.assertEqual(_channels.get_count(rch.id), 3) - sch.send_nowait(42, unbound=channels.UNBOUND_REMOVE) + sch.send_nowait(42, unbounditems=channels.UNBOUND_REMOVE) self.assertEqual(_channels.get_count(rch.id), 4) del interp self.assertEqual(_channels.get_count(rch.id), 2) @@ -484,11 +484,11 @@ def test_send_cleared_with_subinterpreter_mixed(self): _run_output(interp, dedent(f""" from test.support.interpreters import channels sch = channels.SendChannel({sch.id}) - sch.send_nowait(1, unbound=channels.UNBOUND) - sch.send_nowait(2, unbound=channels.UNBOUND_ERROR) + sch.send_nowait(1, unbounditems=channels.UNBOUND) + sch.send_nowait(2, unbounditems=channels.UNBOUND_ERROR) sch.send_nowait(3) - sch.send_nowait(4, unbound=channels.UNBOUND_REMOVE) - sch.send_nowait(5, unbound=channels.UNBOUND) + sch.send_nowait(4, unbounditems=channels.UNBOUND_REMOVE) + sch.send_nowait(5, unbounditems=channels.UNBOUND) """)) self.assertEqual(_channels.get_count(rch.id), 5) @@ -522,8 +522,8 @@ def test_send_cleared_with_subinterpreter_multiple(self): rch = channels.RecvChannel({rch.id}) sch = channels.SendChannel({sch.id}) obj1 = rch.recv() - sch.send_nowait(2, unbound=channels.UNBOUND) - sch.send_nowait(obj1, unbound=channels.UNBOUND_REMOVE) + sch.send_nowait(2, unbounditems=channels.UNBOUND) + sch.send_nowait(obj1, unbounditems=channels.UNBOUND_REMOVE) """)) _run_output(interp2, dedent(f""" from test.support.interpreters import channels @@ -535,21 +535,21 @@ def test_send_cleared_with_subinterpreter_multiple(self): self.assertEqual(_channels.get_count(rch.id), 0) sch.send_nowait(3) _run_output(interp1, dedent(""" - sch.send_nowait(4, unbound=channels.UNBOUND) + sch.send_nowait(4, unbounditems=channels.UNBOUND) # interp closed here - sch.send_nowait(5, unbound=channels.UNBOUND_REMOVE) - sch.send_nowait(6, unbound=channels.UNBOUND) + sch.send_nowait(5, unbounditems=channels.UNBOUND_REMOVE) + sch.send_nowait(6, unbounditems=channels.UNBOUND) """)) _run_output(interp2, dedent(""" - sch.send_nowait(7, unbound=channels.UNBOUND_ERROR) + sch.send_nowait(7, unbounditems=channels.UNBOUND_ERROR) # interp closed here - sch.send_nowait(obj1, unbound=channels.UNBOUND_ERROR) - sch.send_nowait(obj2, unbound=channels.UNBOUND_REMOVE) - sch.send_nowait(8, unbound=channels.UNBOUND) + sch.send_nowait(obj1, unbounditems=channels.UNBOUND_ERROR) + sch.send_nowait(obj2, unbounditems=channels.UNBOUND_REMOVE) + sch.send_nowait(8, unbounditems=channels.UNBOUND) """)) _run_output(interp1, dedent(""" - sch.send_nowait(9, unbound=channels.UNBOUND_REMOVE) - sch.send_nowait(10, unbound=channels.UNBOUND) + sch.send_nowait(9, unbounditems=channels.UNBOUND_REMOVE) + sch.send_nowait(10, unbounditems=channels.UNBOUND) """)) self.assertEqual(_channels.get_count(rch.id), 10) diff --git a/Lib/test/test_interpreters/test_queues.py b/Lib/test/test_interpreters/test_queues.py index 18f83d097eb360..64a2db1230d023 100644 --- a/Lib/test/test_interpreters/test_queues.py +++ b/Lib/test/test_interpreters/test_queues.py @@ -9,6 +9,7 @@ _queues = import_helper.import_module('_interpqueues') from test.support import interpreters from test.support.interpreters import queues, _crossinterp +import test._crossinterp_definitions as defs from .utils import _run_output, TestBase as _TestBase @@ -42,7 +43,7 @@ def test_highlevel_reloaded(self): importlib.reload(queues) def test_create_destroy(self): - qid = _queues.create(2, 0, REPLACE) + qid = _queues.create(2, REPLACE, -1) _queues.destroy(qid) self.assertEqual(get_num_queues(), 0) with self.assertRaises(queues.QueueNotFoundError): @@ -56,7 +57,7 @@ def test_not_destroyed(self): '-c', dedent(f""" import {_queues.__name__} as _queues - _queues.create(2, 0, {REPLACE}) + _queues.create(2, {REPLACE}, -1) """), ) self.assertEqual(stdout, '') @@ -67,13 +68,13 @@ def test_not_destroyed(self): def test_bind_release(self): with self.subTest('typical'): - qid = _queues.create(2, 0, REPLACE) + qid = _queues.create(2, REPLACE, -1) _queues.bind(qid) _queues.release(qid) self.assertEqual(get_num_queues(), 0) with self.subTest('bind too much'): - qid = _queues.create(2, 0, REPLACE) + qid = _queues.create(2, REPLACE, -1) _queues.bind(qid) _queues.bind(qid) _queues.release(qid) @@ -81,7 +82,7 @@ def test_bind_release(self): self.assertEqual(get_num_queues(), 0) with self.subTest('nested'): - qid = _queues.create(2, 0, REPLACE) + qid = _queues.create(2, REPLACE, -1) _queues.bind(qid) _queues.bind(qid) _queues.release(qid) @@ -89,7 +90,7 @@ def test_bind_release(self): self.assertEqual(get_num_queues(), 0) with self.subTest('release without binding'): - qid = _queues.create(2, 0, REPLACE) + qid = _queues.create(2, REPLACE, -1) with self.assertRaises(queues.QueueError): _queues.release(qid) @@ -132,13 +133,13 @@ def test_shareable(self): with self.subTest('same interpreter'): queue2 = queues.create() - queue1.put(queue2, syncobj=True) + queue1.put(queue2) queue3 = queue1.get() self.assertIs(queue3, queue2) with self.subTest('from current interpreter'): queue4 = queues.create() - queue1.put(queue4, syncobj=True) + queue1.put(queue4) out = _run_output(interp, dedent(""" queue4 = queue1.get() print(queue4.id) @@ -149,7 +150,7 @@ def test_shareable(self): with self.subTest('from subinterpreter'): out = _run_output(interp, dedent(""" queue5 = queues.create() - queue1.put(queue5, syncobj=True) + queue1.put(queue5) print(queue5.id) """)) qid = int(out) @@ -198,7 +199,7 @@ class TestQueueOps(TestBase): def test_empty(self): queue = queues.create() before = queue.empty() - queue.put(None, syncobj=True) + queue.put(None) during = queue.empty() queue.get() after = queue.empty() @@ -213,7 +214,7 @@ def test_full(self): queue = queues.create(3) for _ in range(3): actual.append(queue.full()) - queue.put(None, syncobj=True) + queue.put(None) actual.append(queue.full()) for _ in range(3): queue.get() @@ -227,16 +228,16 @@ def test_qsize(self): queue = queues.create() for _ in range(3): actual.append(queue.qsize()) - queue.put(None, syncobj=True) + queue.put(None) actual.append(queue.qsize()) queue.get() actual.append(queue.qsize()) - queue.put(None, syncobj=True) + queue.put(None) actual.append(queue.qsize()) for _ in range(3): queue.get() actual.append(queue.qsize()) - queue.put(None, syncobj=True) + queue.put(None) actual.append(queue.qsize()) queue.get() actual.append(queue.qsize()) @@ -245,70 +246,32 @@ def test_qsize(self): def test_put_get_main(self): expected = list(range(20)) - for syncobj in (True, False): - kwds = dict(syncobj=syncobj) - with self.subTest(f'syncobj={syncobj}'): - queue = queues.create() - for i in range(20): - queue.put(i, **kwds) - actual = [queue.get() for _ in range(20)] + queue = queues.create() + for i in range(20): + queue.put(i) + actual = [queue.get() for _ in range(20)] - self.assertEqual(actual, expected) + self.assertEqual(actual, expected) def test_put_timeout(self): - for syncobj in (True, False): - kwds = dict(syncobj=syncobj) - with self.subTest(f'syncobj={syncobj}'): - queue = queues.create(2) - queue.put(None, **kwds) - queue.put(None, **kwds) - with self.assertRaises(queues.QueueFull): - queue.put(None, timeout=0.1, **kwds) - queue.get() - queue.put(None, **kwds) + queue = queues.create(2) + queue.put(None) + queue.put(None) + with self.assertRaises(queues.QueueFull): + queue.put(None, timeout=0.1) + queue.get() + queue.put(None) def test_put_nowait(self): - for syncobj in (True, False): - kwds = dict(syncobj=syncobj) - with self.subTest(f'syncobj={syncobj}'): - queue = queues.create(2) - queue.put_nowait(None, **kwds) - queue.put_nowait(None, **kwds) - with self.assertRaises(queues.QueueFull): - queue.put_nowait(None, **kwds) - queue.get() - queue.put_nowait(None, **kwds) - - def test_put_syncobj(self): - for obj in [ - None, - True, - 10, - 'spam', - b'spam', - (0, 'a'), - ]: - with self.subTest(repr(obj)): - queue = queues.create() - - queue.put(obj, syncobj=True) - obj2 = queue.get() - self.assertEqual(obj2, obj) - - queue.put(obj, syncobj=True) - obj2 = queue.get_nowait() - self.assertEqual(obj2, obj) - - for obj in [ - [1, 2, 3], - {'a': 13, 'b': 17}, - ]: - with self.subTest(repr(obj)): - queue = queues.create() - with self.assertRaises(interpreters.NotShareableError): - queue.put(obj, syncobj=True) + queue = queues.create(2) + queue.put_nowait(None) + queue.put_nowait(None) + with self.assertRaises(queues.QueueFull): + queue.put_nowait(None) + queue.get() + queue.put_nowait(None) - def test_put_not_syncobj(self): + def test_put_full_fallback(self): for obj in [ None, True, @@ -323,11 +286,11 @@ def test_put_not_syncobj(self): with self.subTest(repr(obj)): queue = queues.create() - queue.put(obj, syncobj=False) + queue.put(obj) obj2 = queue.get() self.assertEqual(obj2, obj) - queue.put(obj, syncobj=False) + queue.put(obj) obj2 = queue.get_nowait() self.assertEqual(obj2, obj) @@ -341,24 +304,9 @@ def test_get_nowait(self): with self.assertRaises(queues.QueueEmpty): queue.get_nowait() - def test_put_get_default_syncobj(self): + def test_put_get_full_fallback(self): expected = list(range(20)) - queue = queues.create(syncobj=True) - for methname in ('get', 'get_nowait'): - with self.subTest(f'{methname}()'): - get = getattr(queue, methname) - for i in range(20): - queue.put(i) - actual = [get() for _ in range(20)] - self.assertEqual(actual, expected) - - obj = [1, 2, 3] # lists are not shareable - with self.assertRaises(interpreters.NotShareableError): - queue.put(obj) - - def test_put_get_default_not_syncobj(self): - expected = list(range(20)) - queue = queues.create(syncobj=False) + queue = queues.create() for methname in ('get', 'get_nowait'): with self.subTest(f'{methname}()'): get = getattr(queue, methname) @@ -384,7 +332,7 @@ def test_put_get_same_interpreter(self): with self.subTest(f'{methname}()'): interp.exec(dedent(f""" orig = b'spam' - queue.put(orig, syncobj=True) + queue.put(orig) obj = queue.{methname}() assert obj == orig, 'expected: obj == orig' assert obj is not orig, 'expected: obj is not orig' @@ -399,7 +347,7 @@ def test_put_get_different_interpreters(self): for methname in ('get', 'get_nowait'): with self.subTest(f'{methname}()'): obj1 = b'spam' - queue1.put(obj1, syncobj=True) + queue1.put(obj1) out = _run_output( interp, @@ -416,7 +364,7 @@ def test_put_get_different_interpreters(self): obj2 = b'eggs' print(id(obj2)) assert queue2.qsize() == 0, 'expected: queue2.qsize() == 0' - queue2.put(obj2, syncobj=True) + queue2.put(obj2) assert queue2.qsize() == 1, 'expected: queue2.qsize() == 1' """)) self.assertEqual(len(queues.list_all()), 2) @@ -433,11 +381,11 @@ def common(queue, unbound=None, presize=0): if not unbound: extraargs = '' elif unbound is queues.UNBOUND: - extraargs = ', unbound=queues.UNBOUND' + extraargs = ', unbounditems=queues.UNBOUND' elif unbound is queues.UNBOUND_ERROR: - extraargs = ', unbound=queues.UNBOUND_ERROR' + extraargs = ', unbounditems=queues.UNBOUND_ERROR' elif unbound is queues.UNBOUND_REMOVE: - extraargs = ', unbound=queues.UNBOUND_REMOVE' + extraargs = ', unbounditems=queues.UNBOUND_REMOVE' else: raise NotImplementedError(repr(unbound)) interp = interpreters.create() @@ -447,8 +395,8 @@ def common(queue, unbound=None, presize=0): queue = queues.Queue({queue.id}) obj1 = b'spam' obj2 = b'eggs' - queue.put(obj1, syncobj=True{extraargs}) - queue.put(obj2, syncobj=True{extraargs}) + queue.put(obj1{extraargs}) + queue.put(obj2{extraargs}) """)) self.assertEqual(queue.qsize(), presize + 2) @@ -501,11 +449,11 @@ def common(queue, unbound=None, presize=0): with self.assertRaises(queues.QueueEmpty): queue.get_nowait() - queue.put(b'ham', unbound=queues.UNBOUND_REMOVE) + queue.put(b'ham', unbounditems=queues.UNBOUND_REMOVE) self.assertEqual(queue.qsize(), 1) interp = common(queue, queues.UNBOUND_REMOVE, 1) self.assertEqual(queue.qsize(), 3) - queue.put(42, unbound=queues.UNBOUND_REMOVE) + queue.put(42, unbounditems=queues.UNBOUND_REMOVE) self.assertEqual(queue.qsize(), 4) del interp self.assertEqual(queue.qsize(), 2) @@ -523,11 +471,11 @@ def test_put_cleared_with_subinterpreter_mixed(self): _run_output(interp, dedent(f""" from test.support.interpreters import queues queue = queues.Queue({queue.id}) - queue.put(1, syncobj=True, unbound=queues.UNBOUND) - queue.put(2, syncobj=True, unbound=queues.UNBOUND_ERROR) - queue.put(3, syncobj=True) - queue.put(4, syncobj=True, unbound=queues.UNBOUND_REMOVE) - queue.put(5, syncobj=True, unbound=queues.UNBOUND) + queue.put(1, unbounditems=queues.UNBOUND) + queue.put(2, unbounditems=queues.UNBOUND_ERROR) + queue.put(3) + queue.put(4, unbounditems=queues.UNBOUND_REMOVE) + queue.put(5, unbounditems=queues.UNBOUND) """)) self.assertEqual(queue.qsize(), 5) @@ -555,13 +503,13 @@ def test_put_cleared_with_subinterpreter_multiple(self): interp1 = interpreters.create() interp2 = interpreters.create() - queue.put(1, syncobj=True) + queue.put(1) _run_output(interp1, dedent(f""" from test.support.interpreters import queues queue = queues.Queue({queue.id}) obj1 = queue.get() - queue.put(2, syncobj=True, unbound=queues.UNBOUND) - queue.put(obj1, syncobj=True, unbound=queues.UNBOUND_REMOVE) + queue.put(2, unbounditems=queues.UNBOUND) + queue.put(obj1, unbounditems=queues.UNBOUND_REMOVE) """)) _run_output(interp2, dedent(f""" from test.support.interpreters import queues @@ -572,21 +520,21 @@ def test_put_cleared_with_subinterpreter_multiple(self): self.assertEqual(queue.qsize(), 0) queue.put(3) _run_output(interp1, dedent(""" - queue.put(4, syncobj=True, unbound=queues.UNBOUND) + queue.put(4, unbounditems=queues.UNBOUND) # interp closed here - queue.put(5, syncobj=True, unbound=queues.UNBOUND_REMOVE) - queue.put(6, syncobj=True, unbound=queues.UNBOUND) + queue.put(5, unbounditems=queues.UNBOUND_REMOVE) + queue.put(6, unbounditems=queues.UNBOUND) """)) _run_output(interp2, dedent(""" - queue.put(7, syncobj=True, unbound=queues.UNBOUND_ERROR) + queue.put(7, unbounditems=queues.UNBOUND_ERROR) # interp closed here - queue.put(obj1, syncobj=True, unbound=queues.UNBOUND_ERROR) - queue.put(obj2, syncobj=True, unbound=queues.UNBOUND_REMOVE) - queue.put(8, syncobj=True, unbound=queues.UNBOUND) + queue.put(obj1, unbounditems=queues.UNBOUND_ERROR) + queue.put(obj2, unbounditems=queues.UNBOUND_REMOVE) + queue.put(8, unbounditems=queues.UNBOUND) """)) _run_output(interp1, dedent(""" - queue.put(9, syncobj=True, unbound=queues.UNBOUND_REMOVE) - queue.put(10, syncobj=True, unbound=queues.UNBOUND) + queue.put(9, unbounditems=queues.UNBOUND_REMOVE) + queue.put(10, unbounditems=queues.UNBOUND) """)) self.assertEqual(queue.qsize(), 10) @@ -642,12 +590,12 @@ def f(): break except queues.QueueEmpty: continue - queue2.put(obj, syncobj=True) + queue2.put(obj) t = threading.Thread(target=f) t.start() orig = b'spam' - queue1.put(orig, syncobj=True) + queue1.put(orig) obj = queue2.get() t.join() diff --git a/Modules/_interpchannelsmodule.c b/Modules/_interpchannelsmodule.c index 172cebcaa4884f..a8d76745fba8b7 100644 --- a/Modules/_interpchannelsmodule.c +++ b/Modules/_interpchannelsmodule.c @@ -20,9 +20,11 @@ #endif #define REGISTERS_HEAP_TYPES +#define HAS_FALLBACK #define HAS_UNBOUND_ITEMS #include "_interpreters_common.h" #undef HAS_UNBOUND_ITEMS +#undef HAS_FALLBACK #undef REGISTERS_HEAP_TYPES @@ -523,7 +525,7 @@ typedef struct _channelitem { int64_t interpid; _PyXIData_t *data; _waiting_t *waiting; - int unboundop; + unboundop_t unboundop; struct _channelitem *next; } _channelitem; @@ -536,7 +538,7 @@ _channelitem_ID(_channelitem *item) static void _channelitem_init(_channelitem *item, int64_t interpid, _PyXIData_t *data, - _waiting_t *waiting, int unboundop) + _waiting_t *waiting, unboundop_t unboundop) { if (interpid < 0) { interpid = _get_interpid(data); @@ -583,7 +585,7 @@ _channelitem_clear(_channelitem *item) static _channelitem * _channelitem_new(int64_t interpid, _PyXIData_t *data, - _waiting_t *waiting, int unboundop) + _waiting_t *waiting, unboundop_t unboundop) { _channelitem *item = GLOBAL_MALLOC(_channelitem); if (item == NULL) { @@ -694,7 +696,7 @@ _channelqueue_free(_channelqueue *queue) static int _channelqueue_put(_channelqueue *queue, int64_t interpid, _PyXIData_t *data, - _waiting_t *waiting, int unboundop) + _waiting_t *waiting, unboundop_t unboundop) { _channelitem *item = _channelitem_new(interpid, data, waiting, unboundop); if (item == NULL) { @@ -798,7 +800,7 @@ _channelqueue_remove(_channelqueue *queue, _channelitem_id_t itemid, } queue->count -= 1; - int unboundop; + unboundop_t unboundop; _channelitem_popped(item, p_data, p_waiting, &unboundop); } @@ -1083,16 +1085,18 @@ typedef struct _channel { PyThread_type_lock mutex; _channelqueue *queue; _channelends *ends; - struct { - int unboundop; + struct _channeldefaults { + unboundop_t unboundop; + xidata_fallback_t fallback; } defaults; int open; struct _channel_closing *closing; } _channel_state; static _channel_state * -_channel_new(PyThread_type_lock mutex, int unboundop) +_channel_new(PyThread_type_lock mutex, struct _channeldefaults defaults) { + assert(check_unbound(defaults.unboundop)); _channel_state *chan = GLOBAL_MALLOC(_channel_state); if (chan == NULL) { return NULL; @@ -1109,7 +1113,7 @@ _channel_new(PyThread_type_lock mutex, int unboundop) GLOBAL_FREE(chan); return NULL; } - chan->defaults.unboundop = unboundop; + chan->defaults = defaults; chan->open = 1; chan->closing = NULL; return chan; @@ -1130,7 +1134,7 @@ _channel_free(_channel_state *chan) static int _channel_add(_channel_state *chan, int64_t interpid, - _PyXIData_t *data, _waiting_t *waiting, int unboundop) + _PyXIData_t *data, _waiting_t *waiting, unboundop_t unboundop) { int res = -1; PyThread_acquire_lock(chan->mutex, WAIT_LOCK); @@ -1611,7 +1615,7 @@ _channels_release_cid_object(_channels *channels, int64_t cid) struct channel_id_and_info { int64_t id; - int unboundop; + struct _channeldefaults defaults; }; static struct channel_id_and_info * @@ -1628,7 +1632,7 @@ _channels_list_all(_channels *channels, int64_t *count) for (int64_t i=0; ref != NULL; ref = ref->next, i++) { ids[i] = (struct channel_id_and_info){ .id = ref->cid, - .unboundop = ref->chan->defaults.unboundop, + .defaults = ref->chan->defaults, }; } *count = channels->numopen; @@ -1714,13 +1718,13 @@ _channel_finish_closing(_channel_state *chan) { // Create a new channel. static int64_t -channel_create(_channels *channels, int unboundop) +channel_create(_channels *channels, struct _channeldefaults defaults) { PyThread_type_lock mutex = PyThread_allocate_lock(); if (mutex == NULL) { return ERR_CHANNEL_MUTEX_INIT; } - _channel_state *chan = _channel_new(mutex, unboundop); + _channel_state *chan = _channel_new(mutex, defaults); if (chan == NULL) { PyThread_free_lock(mutex); return -1; @@ -1752,7 +1756,7 @@ channel_destroy(_channels *channels, int64_t cid) // Optionally request to be notified when it is received. static int channel_send(_channels *channels, int64_t cid, PyObject *obj, - _waiting_t *waiting, int unboundop) + _waiting_t *waiting, unboundop_t unboundop, xidata_fallback_t fallback) { PyThreadState *tstate = _PyThreadState_GET(); PyInterpreterState *interp = tstate->interp; @@ -1779,7 +1783,7 @@ channel_send(_channels *channels, int64_t cid, PyObject *obj, PyThread_release_lock(mutex); return -1; } - if (_PyObject_GetXIData(tstate, obj, data) != 0) { + if (_PyObject_GetXIDataWithFallback(tstate, obj, fallback, data) != 0) { PyThread_release_lock(mutex); GLOBAL_FREE(data); return -1; @@ -1823,7 +1827,8 @@ channel_clear_sent(_channels *channels, int64_t cid, _waiting_t *waiting) // Like channel_send(), but strictly wait for the object to be received. static int channel_send_wait(_channels *channels, int64_t cid, PyObject *obj, - int unboundop, PY_TIMEOUT_T timeout) + unboundop_t unboundop, PY_TIMEOUT_T timeout, + xidata_fallback_t fallback) { // We use a stack variable here, so we must ensure that &waiting // is not held by any channel item at the point this function exits. @@ -1834,7 +1839,7 @@ channel_send_wait(_channels *channels, int64_t cid, PyObject *obj, } /* Queue up the object. */ - int res = channel_send(channels, cid, obj, &waiting, unboundop); + int res = channel_send(channels, cid, obj, &waiting, unboundop, fallback); if (res < 0) { assert(waiting.status == WAITING_NO_STATUS); goto finally; @@ -2005,6 +2010,20 @@ channel_is_associated(_channels *channels, int64_t cid, int64_t interpid, return (end != NULL && end->open); } +static int +channel_get_defaults(_channels *channels, int64_t cid, struct _channeldefaults *defaults) +{ + PyThread_type_lock mutex = NULL; + _channel_state *channel = NULL; + int err = _channels_lookup(channels, cid, &mutex, &channel); + if (err != 0) { + return err; + } + *defaults = channel->defaults; + PyThread_release_lock(mutex); + return 0; +} + static int _channel_get_count(_channels *channels, int64_t cid, Py_ssize_t *p_count) { @@ -2694,7 +2713,7 @@ add_channelid_type(PyObject *mod) Py_DECREF(cls); return NULL; } - if (ensure_xid_class(cls, _channelid_shared) < 0) { + if (ensure_xid_class(cls, GETDATA(_channelid_shared)) < 0) { Py_DECREF(cls); return NULL; } @@ -2797,12 +2816,12 @@ set_channelend_types(PyObject *mod, PyTypeObject *send, PyTypeObject *recv) // Add and register the types. state->send_channel_type = (PyTypeObject *)Py_NewRef(send); state->recv_channel_type = (PyTypeObject *)Py_NewRef(recv); - if (ensure_xid_class(send, _channelend_shared) < 0) { + if (ensure_xid_class(send, GETDATA(_channelend_shared)) < 0) { Py_CLEAR(state->send_channel_type); Py_CLEAR(state->recv_channel_type); return -1; } - if (ensure_xid_class(recv, _channelend_shared) < 0) { + if (ensure_xid_class(recv, GETDATA(_channelend_shared)) < 0) { (void)clear_xid_class(state->send_channel_type); Py_CLEAR(state->send_channel_type); Py_CLEAR(state->recv_channel_type); @@ -2881,20 +2900,27 @@ clear_interpreter(void *data) static PyObject * channelsmod_create(PyObject *self, PyObject *args, PyObject *kwds) { - static char *kwlist[] = {"unboundop", NULL}; - int unboundop; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "i:create", kwlist, - &unboundop)) + static char *kwlist[] = {"unboundop", "fallback", NULL}; + int unboundarg = -1; + int fallbackarg = -1; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|ii:create", kwlist, + &unboundarg, &fallbackarg)) { return NULL; } - if (!check_unbound(unboundop)) { - PyErr_Format(PyExc_ValueError, - "unsupported unboundop %d", unboundop); + struct _channeldefaults defaults = {0}; + if (resolve_unboundop(unboundarg, UNBOUND_REPLACE, + &defaults.unboundop) < 0) + { + return NULL; + } + if (resolve_fallback(fallbackarg, _PyXIDATA_FULL_FALLBACK, + &defaults.fallback) < 0) + { return NULL; } - int64_t cid = channel_create(&_globals.channels, unboundop); + int64_t cid = channel_create(&_globals.channels, defaults); if (cid < 0) { (void)handle_channel_error(-1, self, cid); return NULL; @@ -2987,7 +3013,9 @@ channelsmod_list_all(PyObject *self, PyObject *Py_UNUSED(ignored)) } assert(cidobj != NULL); - PyObject *item = Py_BuildValue("Oi", cidobj, cur->unboundop); + PyObject *item = Py_BuildValue("Oii", cidobj, + cur->defaults.unboundop, + cur->defaults.fallback); Py_DECREF(cidobj); if (item == NULL) { Py_SETREF(ids, NULL); @@ -3075,40 +3103,54 @@ receive end."); static PyObject * channelsmod_send(PyObject *self, PyObject *args, PyObject *kwds) { - static char *kwlist[] = {"cid", "obj", "unboundop", "blocking", "timeout", - NULL}; + static char *kwlist[] = {"cid", "obj", "unboundop", "fallback", + "blocking", "timeout", NULL}; struct channel_id_converter_data cid_data = { .module = self, }; PyObject *obj; - int unboundop = UNBOUND_REPLACE; + int unboundarg = -1; + int fallbackarg = -1; int blocking = 1; PyObject *timeout_obj = NULL; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&O|i$pO:channel_send", kwlist, + if (!PyArg_ParseTupleAndKeywords(args, kwds, + "O&O|ii$pO:channel_send", kwlist, channel_id_converter, &cid_data, &obj, - &unboundop, &blocking, &timeout_obj)) + &unboundarg, &fallbackarg, + &blocking, &timeout_obj)) { return NULL; } - if (!check_unbound(unboundop)) { - PyErr_Format(PyExc_ValueError, - "unsupported unboundop %d", unboundop); - return NULL; - } - int64_t cid = cid_data.cid; PY_TIMEOUT_T timeout; if (PyThread_ParseTimeoutArg(timeout_obj, blocking, &timeout) < 0) { return NULL; } + struct _channeldefaults defaults = {-1, -1}; + if (unboundarg < 0 || fallbackarg < 0) { + int err = channel_get_defaults(&_globals.channels, cid, &defaults); + if (handle_channel_error(err, self, cid)) { + return NULL; + } + } + unboundop_t unboundop; + if (resolve_unboundop(unboundarg, defaults.unboundop, &unboundop) < 0) { + return NULL; + } + xidata_fallback_t fallback; + if (resolve_fallback(fallbackarg, defaults.fallback, &fallback) < 0) { + return NULL; + } /* Queue up the object. */ int err = 0; if (blocking) { - err = channel_send_wait(&_globals.channels, cid, obj, unboundop, timeout); + err = channel_send_wait( + &_globals.channels, cid, obj, unboundop, timeout, fallback); } else { - err = channel_send(&_globals.channels, cid, obj, NULL, unboundop); + err = channel_send( + &_globals.channels, cid, obj, NULL, unboundop, fallback); } if (handle_channel_error(err, self, cid)) { return NULL; @@ -3126,32 +3168,44 @@ By default this waits for the object to be received."); static PyObject * channelsmod_send_buffer(PyObject *self, PyObject *args, PyObject *kwds) { - static char *kwlist[] = {"cid", "obj", "unboundop", "blocking", "timeout", - NULL}; + static char *kwlist[] = {"cid", "obj", "unboundop", "fallback", + "blocking", "timeout", NULL}; struct channel_id_converter_data cid_data = { .module = self, }; PyObject *obj; - int unboundop = UNBOUND_REPLACE; - int blocking = 1; + int unboundarg = -1; + int fallbackarg = -1; + int blocking = -1; PyObject *timeout_obj = NULL; if (!PyArg_ParseTupleAndKeywords(args, kwds, - "O&O|i$pO:channel_send_buffer", kwlist, + "O&O|ii$pO:channel_send_buffer", kwlist, channel_id_converter, &cid_data, &obj, - &unboundop, &blocking, &timeout_obj)) { - return NULL; - } - if (!check_unbound(unboundop)) { - PyErr_Format(PyExc_ValueError, - "unsupported unboundop %d", unboundop); + &unboundarg, &fallbackarg, + &blocking, &timeout_obj)) + { return NULL; } - int64_t cid = cid_data.cid; PY_TIMEOUT_T timeout; if (PyThread_ParseTimeoutArg(timeout_obj, blocking, &timeout) < 0) { return NULL; } + struct _channeldefaults defaults = {-1, -1}; + if (unboundarg < 0 || fallbackarg < 0) { + int err = channel_get_defaults(&_globals.channels, cid, &defaults); + if (handle_channel_error(err, self, cid)) { + return NULL; + } + } + unboundop_t unboundop; + if (resolve_unboundop(unboundarg, defaults.unboundop, &unboundop) < 0) { + return NULL; + } + xidata_fallback_t fallback; + if (resolve_fallback(fallbackarg, defaults.fallback, &fallback) < 0) { + return NULL; + } PyObject *tempobj = PyMemoryView_FromObject(obj); if (tempobj == NULL) { @@ -3162,10 +3216,11 @@ channelsmod_send_buffer(PyObject *self, PyObject *args, PyObject *kwds) int err = 0; if (blocking) { err = channel_send_wait( - &_globals.channels, cid, tempobj, unboundop, timeout); + &_globals.channels, cid, tempobj, unboundop, timeout, fallback); } else { - err = channel_send(&_globals.channels, cid, tempobj, NULL, unboundop); + err = channel_send( + &_globals.channels, cid, tempobj, NULL, unboundop, fallback); } Py_DECREF(tempobj); if (handle_channel_error(err, self, cid)) { @@ -3197,7 +3252,7 @@ channelsmod_recv(PyObject *self, PyObject *args, PyObject *kwds) cid = cid_data.cid; PyObject *obj = NULL; - int unboundop = 0; + unboundop_t unboundop = 0; int err = channel_recv(&_globals.channels, cid, &obj, &unboundop); if (err == ERR_CHANNEL_EMPTY && dflt != NULL) { // Use the default. @@ -3388,17 +3443,14 @@ channelsmod_get_channel_defaults(PyObject *self, PyObject *args, PyObject *kwds) } int64_t cid = cid_data.cid; - PyThread_type_lock mutex = NULL; - _channel_state *channel = NULL; - int err = _channels_lookup(&_globals.channels, cid, &mutex, &channel); + struct _channeldefaults defaults; + int err = channel_get_defaults(&_globals.channels, cid, &defaults); if (handle_channel_error(err, self, cid)) { return NULL; } - int unboundop = channel->defaults.unboundop; - PyThread_release_lock(mutex); - PyObject *defaults = Py_BuildValue("i", unboundop); - return defaults; + PyObject *res = Py_BuildValue("ii", defaults.unboundop, defaults.fallback); + return res; } PyDoc_STRVAR(channelsmod_get_channel_defaults_doc, diff --git a/Modules/_interpqueuesmodule.c b/Modules/_interpqueuesmodule.c index 526249a0e1aec3..ee8d5ce915c0d2 100644 --- a/Modules/_interpqueuesmodule.c +++ b/Modules/_interpqueuesmodule.c @@ -9,9 +9,11 @@ #include "pycore_crossinterp.h" // _PyXIData_t #define REGISTERS_HEAP_TYPES +#define HAS_FALLBACK #define HAS_UNBOUND_ITEMS #include "_interpreters_common.h" #undef HAS_UNBOUND_ITEMS +#undef HAS_FALLBACK #undef REGISTERS_HEAP_TYPES @@ -401,14 +403,13 @@ typedef struct _queueitem { meaning the interpreter has been destroyed. */ int64_t interpid; _PyXIData_t *data; - int fmt; - int unboundop; + unboundop_t unboundop; struct _queueitem *next; } _queueitem; static void _queueitem_init(_queueitem *item, - int64_t interpid, _PyXIData_t *data, int fmt, int unboundop) + int64_t interpid, _PyXIData_t *data, unboundop_t unboundop) { if (interpid < 0) { interpid = _get_interpid(data); @@ -422,7 +423,6 @@ _queueitem_init(_queueitem *item, *item = (_queueitem){ .interpid = interpid, .data = data, - .fmt = fmt, .unboundop = unboundop, }; } @@ -446,14 +446,14 @@ _queueitem_clear(_queueitem *item) } static _queueitem * -_queueitem_new(int64_t interpid, _PyXIData_t *data, int fmt, int unboundop) +_queueitem_new(int64_t interpid, _PyXIData_t *data, int unboundop) { _queueitem *item = GLOBAL_MALLOC(_queueitem); if (item == NULL) { PyErr_NoMemory(); return NULL; } - _queueitem_init(item, interpid, data, fmt, unboundop); + _queueitem_init(item, interpid, data, unboundop); return item; } @@ -476,10 +476,9 @@ _queueitem_free_all(_queueitem *item) static void _queueitem_popped(_queueitem *item, - _PyXIData_t **p_data, int *p_fmt, int *p_unboundop) + _PyXIData_t **p_data, unboundop_t *p_unboundop) { *p_data = item->data; - *p_fmt = item->fmt; *p_unboundop = item->unboundop; // We clear them here, so they won't be released in _queueitem_clear(). item->data = NULL; @@ -527,16 +526,16 @@ typedef struct _queue { _queueitem *first; _queueitem *last; } items; - struct { - int fmt; + struct _queuedefaults { + xidata_fallback_t fallback; int unboundop; } defaults; } _queue; static int -_queue_init(_queue *queue, Py_ssize_t maxsize, int fmt, int unboundop) +_queue_init(_queue *queue, Py_ssize_t maxsize, struct _queuedefaults defaults) { - assert(check_unbound(unboundop)); + assert(check_unbound(defaults.unboundop)); PyThread_type_lock mutex = PyThread_allocate_lock(); if (mutex == NULL) { return ERR_QUEUE_ALLOC; @@ -547,10 +546,7 @@ _queue_init(_queue *queue, Py_ssize_t maxsize, int fmt, int unboundop) .items = { .maxsize = maxsize, }, - .defaults = { - .fmt = fmt, - .unboundop = unboundop, - }, + .defaults = defaults, }; return 0; } @@ -631,8 +627,7 @@ _queue_unlock(_queue *queue) } static int -_queue_add(_queue *queue, int64_t interpid, _PyXIData_t *data, - int fmt, int unboundop) +_queue_add(_queue *queue, int64_t interpid, _PyXIData_t *data, int unboundop) { int err = _queue_lock(queue); if (err < 0) { @@ -648,7 +643,7 @@ _queue_add(_queue *queue, int64_t interpid, _PyXIData_t *data, return ERR_QUEUE_FULL; } - _queueitem *item = _queueitem_new(interpid, data, fmt, unboundop); + _queueitem *item = _queueitem_new(interpid, data, unboundop); if (item == NULL) { _queue_unlock(queue); return -1; @@ -668,8 +663,7 @@ _queue_add(_queue *queue, int64_t interpid, _PyXIData_t *data, } static int -_queue_next(_queue *queue, - _PyXIData_t **p_data, int *p_fmt, int *p_unboundop) +_queue_next(_queue *queue, _PyXIData_t **p_data, int *p_unboundop) { int err = _queue_lock(queue); if (err < 0) { @@ -688,7 +682,7 @@ _queue_next(_queue *queue, } queue->items.count -= 1; - _queueitem_popped(item, p_data, p_fmt, p_unboundop); + _queueitem_popped(item, p_data, p_unboundop); _queue_unlock(queue); return 0; @@ -1035,8 +1029,7 @@ _queues_decref(_queues *queues, int64_t qid) struct queue_id_and_info { int64_t id; - int fmt; - int unboundop; + struct _queuedefaults defaults; }; static struct queue_id_and_info * @@ -1053,8 +1046,7 @@ _queues_list_all(_queues *queues, int64_t *p_count) for (int64_t i=0; ref != NULL; ref = ref->next, i++) { ids[i].id = ref->qid; assert(ref->queue != NULL); - ids[i].fmt = ref->queue->defaults.fmt; - ids[i].unboundop = ref->queue->defaults.unboundop; + ids[i].defaults = ref->queue->defaults; } *p_count = queues->count; @@ -1090,13 +1082,14 @@ _queue_free(_queue *queue) // Create a new queue. static int64_t -queue_create(_queues *queues, Py_ssize_t maxsize, int fmt, int unboundop) +queue_create(_queues *queues, Py_ssize_t maxsize, + struct _queuedefaults defaults) { _queue *queue = GLOBAL_MALLOC(_queue); if (queue == NULL) { return ERR_QUEUE_ALLOC; } - int err = _queue_init(queue, maxsize, fmt, unboundop); + int err = _queue_init(queue, maxsize, defaults); if (err < 0) { GLOBAL_FREE(queue); return (int64_t)err; @@ -1125,7 +1118,8 @@ queue_destroy(_queues *queues, int64_t qid) // Push an object onto the queue. static int -queue_put(_queues *queues, int64_t qid, PyObject *obj, int fmt, int unboundop) +queue_put(_queues *queues, int64_t qid, PyObject *obj, unboundop_t unboundop, + xidata_fallback_t fallback) { PyThreadState *tstate = PyThreadState_Get(); @@ -1138,27 +1132,27 @@ queue_put(_queues *queues, int64_t qid, PyObject *obj, int fmt, int unboundop) assert(queue != NULL); // Convert the object to cross-interpreter data. - _PyXIData_t *data = _PyXIData_New(); - if (data == NULL) { + _PyXIData_t *xidata = _PyXIData_New(); + if (xidata == NULL) { _queue_unmark_waiter(queue, queues->mutex); return -1; } - if (_PyObject_GetXIData(tstate, obj, data) != 0) { + if (_PyObject_GetXIDataWithFallback(tstate, obj, fallback, xidata) != 0) { _queue_unmark_waiter(queue, queues->mutex); - GLOBAL_FREE(data); + GLOBAL_FREE(xidata); return -1; } - assert(_PyXIData_INTERPID(data) == + assert(_PyXIData_INTERPID(xidata) == PyInterpreterState_GetID(tstate->interp)); // Add the data to the queue. int64_t interpid = -1; // _queueitem_init() will set it. - int res = _queue_add(queue, interpid, data, fmt, unboundop); + int res = _queue_add(queue, interpid, xidata, unboundop); _queue_unmark_waiter(queue, queues->mutex); if (res != 0) { // We may chain an exception here: - (void)_release_xid_data(data, 0); - GLOBAL_FREE(data); + (void)_release_xid_data(xidata, 0); + GLOBAL_FREE(xidata); return res; } @@ -1169,7 +1163,7 @@ queue_put(_queues *queues, int64_t qid, PyObject *obj, int fmt, int unboundop) // XXX Support a "wait" mutex? static int queue_get(_queues *queues, int64_t qid, - PyObject **res, int *p_fmt, int *p_unboundop) + PyObject **res, int *p_unboundop) { int err; *res = NULL; @@ -1185,7 +1179,7 @@ queue_get(_queues *queues, int64_t qid, // Pop off the next item from the queue. _PyXIData_t *data = NULL; - err = _queue_next(queue, &data, p_fmt, p_unboundop); + err = _queue_next(queue, &data, p_unboundop); _queue_unmark_waiter(queue, queues->mutex); if (err != 0) { return err; @@ -1216,6 +1210,20 @@ queue_get(_queues *queues, int64_t qid, return 0; } +static int +queue_get_defaults(_queues *queues, int64_t qid, + struct _queuedefaults *p_defaults) +{ + _queue *queue = NULL; + int err = _queues_lookup(queues, qid, &queue); + if (err != 0) { + return err; + } + *p_defaults = queue->defaults; + _queue_unmark_waiter(queue, queues->mutex); + return 0; +} + static int queue_get_maxsize(_queues *queues, int64_t qid, Py_ssize_t *p_maxsize) { @@ -1270,7 +1278,7 @@ set_external_queue_type(module_state *state, PyTypeObject *queue_type) } // Add and register the new type. - if (ensure_xid_class(queue_type, _queueobj_shared) < 0) { + if (ensure_xid_class(queue_type, GETDATA(_queueobj_shared)) < 0) { return -1; } state->queue_type = (PyTypeObject *)Py_NewRef(queue_type); @@ -1474,22 +1482,28 @@ qidarg_converter(PyObject *arg, void *ptr) static PyObject * queuesmod_create(PyObject *self, PyObject *args, PyObject *kwds) { - static char *kwlist[] = {"maxsize", "fmt", "unboundop", NULL}; + static char *kwlist[] = {"maxsize", "unboundop", "fallback", NULL}; Py_ssize_t maxsize; - int fmt; - int unboundop; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "nii:create", kwlist, - &maxsize, &fmt, &unboundop)) + int unboundarg = -1; + int fallbackarg = -1; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "n|ii:create", kwlist, + &maxsize, &unboundarg, &fallbackarg)) { return NULL; } - if (!check_unbound(unboundop)) { - PyErr_Format(PyExc_ValueError, - "unsupported unboundop %d", unboundop); + struct _queuedefaults defaults = {0}; + if (resolve_unboundop(unboundarg, UNBOUND_REPLACE, + &defaults.unboundop) < 0) + { + return NULL; + } + if (resolve_fallback(fallbackarg, _PyXIDATA_FULL_FALLBACK, + &defaults.fallback) < 0) + { return NULL; } - int64_t qid = queue_create(&_globals.queues, maxsize, fmt, unboundop); + int64_t qid = queue_create(&_globals.queues, maxsize, defaults); if (qid < 0) { (void)handle_queue_error((int)qid, self, qid); return NULL; @@ -1511,7 +1525,7 @@ queuesmod_create(PyObject *self, PyObject *args, PyObject *kwds) } PyDoc_STRVAR(queuesmod_create_doc, -"create(maxsize, fmt, unboundop) -> qid\n\ +"create(maxsize, unboundop, fallback) -> qid\n\ \n\ Create a new cross-interpreter queue and return its unique generated ID.\n\ It is a new reference as though bind() had been called on the queue.\n\ @@ -1560,8 +1574,9 @@ queuesmod_list_all(PyObject *self, PyObject *Py_UNUSED(ignored)) } struct queue_id_and_info *cur = qids; for (int64_t i=0; i < count; cur++, i++) { - PyObject *item = Py_BuildValue("Lii", cur->id, cur->fmt, - cur->unboundop); + PyObject *item = Py_BuildValue("Lii", cur->id, + cur->defaults.unboundop, + cur->defaults.fallback); if (item == NULL) { Py_SETREF(ids, NULL); break; @@ -1575,34 +1590,44 @@ queuesmod_list_all(PyObject *self, PyObject *Py_UNUSED(ignored)) } PyDoc_STRVAR(queuesmod_list_all_doc, -"list_all() -> [(qid, fmt)]\n\ +"list_all() -> [(qid, unboundop, fallback)]\n\ \n\ Return the list of IDs for all queues.\n\ -Each corresponding default format is also included."); +Each corresponding default unbound op and fallback is also included."); static PyObject * queuesmod_put(PyObject *self, PyObject *args, PyObject *kwds) { - static char *kwlist[] = {"qid", "obj", "fmt", "unboundop", NULL}; + static char *kwlist[] = {"qid", "obj", "unboundop", "fallback", NULL}; qidarg_converter_data qidarg = {0}; PyObject *obj; - int fmt; - int unboundop; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&Oii:put", kwlist, - qidarg_converter, &qidarg, &obj, &fmt, - &unboundop)) + int unboundarg = -1; + int fallbackarg = -1; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&O|ii$p:put", kwlist, + qidarg_converter, &qidarg, &obj, + &unboundarg, &fallbackarg)) { return NULL; } int64_t qid = qidarg.id; - if (!check_unbound(unboundop)) { - PyErr_Format(PyExc_ValueError, - "unsupported unboundop %d", unboundop); + struct _queuedefaults defaults = {-1, -1}; + if (unboundarg < 0 || fallbackarg < 0) { + int err = queue_get_defaults(&_globals.queues, qid, &defaults); + if (handle_queue_error(err, self, qid)) { + return NULL; + } + } + unboundop_t unboundop; + if (resolve_unboundop(unboundarg, defaults.unboundop, &unboundop) < 0) { + return NULL; + } + xidata_fallback_t fallback; + if (resolve_fallback(fallbackarg, defaults.fallback, &fallback) < 0) { return NULL; } /* Queue up the object. */ - int err = queue_put(&_globals.queues, qid, obj, fmt, unboundop); + int err = queue_put(&_globals.queues, qid, obj, unboundop, fallback); // This is the only place that raises QueueFull. if (handle_queue_error(err, self, qid)) { return NULL; @@ -1612,7 +1637,7 @@ queuesmod_put(PyObject *self, PyObject *args, PyObject *kwds) } PyDoc_STRVAR(queuesmod_put_doc, -"put(qid, obj, fmt)\n\ +"put(qid, obj)\n\ \n\ Add the object's data to the queue."); @@ -1628,27 +1653,26 @@ queuesmod_get(PyObject *self, PyObject *args, PyObject *kwds) int64_t qid = qidarg.id; PyObject *obj = NULL; - int fmt = 0; int unboundop = 0; - int err = queue_get(&_globals.queues, qid, &obj, &fmt, &unboundop); + int err = queue_get(&_globals.queues, qid, &obj, &unboundop); // This is the only place that raises QueueEmpty. if (handle_queue_error(err, self, qid)) { return NULL; } if (obj == NULL) { - return Py_BuildValue("Oii", Py_None, fmt, unboundop); + return Py_BuildValue("Oi", Py_None, unboundop); } - PyObject *res = Py_BuildValue("OiO", obj, fmt, Py_None); + PyObject *res = Py_BuildValue("OO", obj, Py_None); Py_DECREF(obj); return res; } PyDoc_STRVAR(queuesmod_get_doc, -"get(qid) -> (obj, fmt)\n\ +"get(qid) -> (obj, unboundop)\n\ \n\ Return a new object from the data at the front of the queue.\n\ -The object's format is also returned.\n\ +The unbound op is also returned.\n\ \n\ If there is nothing to receive then raise QueueEmpty."); @@ -1748,17 +1772,14 @@ queuesmod_get_queue_defaults(PyObject *self, PyObject *args, PyObject *kwds) } int64_t qid = qidarg.id; - _queue *queue = NULL; - int err = _queues_lookup(&_globals.queues, qid, &queue); + struct _queuedefaults defaults; + int err = queue_get_defaults(&_globals.queues, qid, &defaults); if (handle_queue_error(err, self, qid)) { return NULL; } - int fmt = queue->defaults.fmt; - int unboundop = queue->defaults.unboundop; - _queue_unmark_waiter(queue, _globals.queues.mutex); - PyObject *defaults = Py_BuildValue("ii", fmt, unboundop); - return defaults; + PyObject *res = Py_BuildValue("ii", defaults.unboundop, defaults.fallback); + return res; } PyDoc_STRVAR(queuesmod_get_queue_defaults_doc, diff --git a/Modules/_interpreters_common.h b/Modules/_interpreters_common.h index edd65577284a20..40fd51d752e324 100644 --- a/Modules/_interpreters_common.h +++ b/Modules/_interpreters_common.h @@ -5,8 +5,10 @@ _RESOLVE_MODINIT_FUNC_NAME(NAME) +#define GETDATA(FUNC) ((_PyXIData_getdata_t){.basic=FUNC}) + static int -ensure_xid_class(PyTypeObject *cls, xidatafunc getdata) +ensure_xid_class(PyTypeObject *cls, _PyXIData_getdata_t getdata) { PyThreadState *tstate = PyThreadState_Get(); return _PyXIData_RegisterClass(tstate, cls, getdata); @@ -37,10 +39,37 @@ _get_interpid(_PyXIData_t *data) } +#ifdef HAS_FALLBACK +static int +resolve_fallback(int arg, xidata_fallback_t dflt, + xidata_fallback_t *p_fallback) +{ + if (arg < 0) { + *p_fallback = dflt; + return 0; + } + xidata_fallback_t fallback; + if (arg == _PyXIDATA_XIDATA_ONLY) { + fallback =_PyXIDATA_XIDATA_ONLY; + } + else if (arg == _PyXIDATA_FULL_FALLBACK) { + fallback = _PyXIDATA_FULL_FALLBACK; + } + else { + PyErr_Format(PyExc_ValueError, "unsupported fallback %d", arg); + return -1; + } + *p_fallback = fallback; + return 0; +} +#endif + + /* unbound items ************************************************************/ #ifdef HAS_UNBOUND_ITEMS +typedef int unboundop_t; #define UNBOUND_REMOVE 1 #define UNBOUND_ERROR 2 #define UNBOUND_REPLACE 3 @@ -51,6 +80,7 @@ _get_interpid(_PyXIData_t *data) // object is released but the underlying data is copied (with the "raw" // allocator) and used when the item is popped off the queue. +#ifndef NDEBUG static int check_unbound(int unboundop) { @@ -63,5 +93,31 @@ check_unbound(int unboundop) return 0; } } +#endif + +static int +resolve_unboundop(int arg, unboundop_t dflt, unboundop_t *p_unboundop) +{ + if (arg < 0) { + *p_unboundop = dflt; + return 0; + } + unboundop_t op; + if (arg == UNBOUND_REMOVE) { + op = UNBOUND_REMOVE; + } + else if (arg == UNBOUND_ERROR) { + op = UNBOUND_ERROR; + } + else if (arg == UNBOUND_REPLACE) { + op = UNBOUND_REPLACE; + } + else { + PyErr_Format(PyExc_ValueError, "unsupported unboundop %d", arg); + return -1; + } + *p_unboundop = op; + return 0; +} #endif diff --git a/Modules/_interpretersmodule.c b/Modules/_interpretersmodule.c index 77678f7c126005..376517ab92360f 100644 --- a/Modules/_interpretersmodule.c +++ b/Modules/_interpretersmodule.c @@ -8,6 +8,7 @@ #include "Python.h" #include "pycore_code.h" // _PyCode_HAS_EXECUTORS() #include "pycore_crossinterp.h" // _PyXIData_t +#include "pycore_pyerrors.h" // _PyErr_GetRaisedException() #include "pycore_interp.h" // _PyInterpreterState_IDIncref() #include "pycore_modsupport.h" // _PyArg_BadArgument() #include "pycore_namespace.h" // _PyNamespace_New() @@ -286,7 +287,7 @@ register_memoryview_xid(PyObject *mod, PyTypeObject **p_state) *p_state = cls; // Register XID for the builtin memoryview type. - if (ensure_xid_class(&PyMemoryView_Type, _pybuffer_shared) < 0) { + if (ensure_xid_class(&PyMemoryView_Type, GETDATA(_pybuffer_shared)) < 0) { return -1; } // We don't ever bother un-registering memoryview. @@ -359,96 +360,6 @@ _get_current_xibufferview_type(void) } -/* Python code **************************************************************/ - -static const char * -check_code_str(PyUnicodeObject *text) -{ - assert(text != NULL); - if (PyUnicode_GET_LENGTH(text) == 0) { - return "too short"; - } - - // XXX Verify that it parses? - - return NULL; -} - -static const char * -check_code_object(PyCodeObject *code) -{ - assert(code != NULL); - if (code->co_argcount > 0 - || code->co_posonlyargcount > 0 - || code->co_kwonlyargcount > 0 - || code->co_flags & (CO_VARARGS | CO_VARKEYWORDS)) - { - return "arguments not supported"; - } - if (code->co_ncellvars > 0) { - return "closures not supported"; - } - // We trust that no code objects under co_consts have unbound cell vars. - - if (_PyCode_HAS_EXECUTORS(code) || _PyCode_HAS_INSTRUMENTATION(code)) { - return "only basic functions are supported"; - } - if (code->_co_monitoring != NULL) { - return "only basic functions are supported"; - } - if (code->co_extra != NULL) { - return "only basic functions are supported"; - } - - return NULL; -} - -#define RUN_TEXT 1 -#define RUN_CODE 2 - -static const char * -get_code_str(PyObject *arg, Py_ssize_t *len_p, PyObject **bytes_p, int *flags_p) -{ - const char *codestr = NULL; - Py_ssize_t len = -1; - PyObject *bytes_obj = NULL; - int flags = 0; - - if (PyUnicode_Check(arg)) { - assert(PyUnicode_Check(arg) - && (check_code_str((PyUnicodeObject *)arg) == NULL)); - codestr = PyUnicode_AsUTF8AndSize(arg, &len); - if (codestr == NULL) { - return NULL; - } - if (strlen(codestr) != (size_t)len) { - PyErr_SetString(PyExc_ValueError, - "source code string cannot contain null bytes"); - return NULL; - } - flags = RUN_TEXT; - } - else { - assert(PyCode_Check(arg) - && (check_code_object((PyCodeObject *)arg) == NULL)); - flags = RUN_CODE; - - // Serialize the code object. - bytes_obj = PyMarshal_WriteObjectToString(arg, Py_MARSHAL_VERSION); - if (bytes_obj == NULL) { - return NULL; - } - codestr = PyBytes_AS_STRING(bytes_obj); - len = PyBytes_GET_SIZE(bytes_obj); - } - - *flags_p = flags; - *bytes_p = bytes_obj; - *len_p = len; - return codestr; -} - - /* interpreter-specific code ************************************************/ static int @@ -512,22 +423,14 @@ config_from_object(PyObject *configobj, PyInterpreterConfig *config) static int -_run_script(PyObject *ns, const char *codestr, Py_ssize_t codestrlen, int flags) +_run_script(_PyXIData_t *script, PyObject *ns) { - PyObject *result = NULL; - if (flags & RUN_TEXT) { - result = PyRun_StringFlags(codestr, Py_file_input, ns, ns, NULL); - } - else if (flags & RUN_CODE) { - PyObject *code = PyMarshal_ReadObjectFromString(codestr, codestrlen); - if (code != NULL) { - result = PyEval_EvalCode(code, ns, ns); - Py_DECREF(code); - } - } - else { - Py_UNREACHABLE(); + PyObject *code = _PyXIData_NewObject(script); + if (code == NULL) { + return -1; } + PyObject *result = PyEval_EvalCode(code, ns, ns); + Py_DECREF(code); if (result == NULL) { return -1; } @@ -536,48 +439,59 @@ _run_script(PyObject *ns, const char *codestr, Py_ssize_t codestrlen, int flags) } static int -_run_in_interpreter(PyInterpreterState *interp, - const char *codestr, Py_ssize_t codestrlen, - PyObject *shareables, int flags, +_exec_in_interpreter(PyThreadState *tstate, PyInterpreterState *interp, + _PyXIData_t *script, PyObject *shareables, PyObject **p_excinfo) { - assert(!PyErr_Occurred()); - _PyXI_session session = {0}; + assert(!_PyErr_Occurred(tstate)); + _PyXI_session *session = _PyXI_NewSession(); + if (session == NULL) { + return -1; + } // Prep and switch interpreters. - if (_PyXI_Enter(&session, interp, shareables) < 0) { - if (PyErr_Occurred()) { + if (_PyXI_Enter(session, interp, shareables) < 0) { + if (_PyErr_Occurred(tstate)) { // If an error occured at this step, it means that interp // was not prepared and switched. + _PyXI_FreeSession(session); return -1; } // Now, apply the error from another interpreter: - PyObject *excinfo = _PyXI_ApplyError(session.error); + PyObject *excinfo = _PyXI_ApplyCapturedException(session); if (excinfo != NULL) { *p_excinfo = excinfo; } assert(PyErr_Occurred()); + _PyXI_FreeSession(session); return -1; } // Run the script. - int res = _run_script(session.main_ns, codestr, codestrlen, flags); + int res = -1; + PyObject *mainns = _PyXI_GetMainNamespace(session); + if (mainns == NULL) { + goto finally; + } + res = _run_script(script, mainns); +finally: // Clean up and switch back. - _PyXI_Exit(&session); + _PyXI_Exit(session); // Propagate any exception out to the caller. assert(!PyErr_Occurred()); if (res < 0) { - PyObject *excinfo = _PyXI_ApplyCapturedException(&session); + PyObject *excinfo = _PyXI_ApplyCapturedException(session); if (excinfo != NULL) { *p_excinfo = excinfo; } } else { - assert(!_PyXI_HasCapturedException(&session)); + assert(!_PyXI_HasCapturedException(session)); } + _PyXI_FreeSession(session); return res; } @@ -922,22 +836,27 @@ interp_set___main___attrs(PyObject *self, PyObject *args, PyObject *kwargs) } } - _PyXI_session session = {0}; + _PyXI_session *session = _PyXI_NewSession(); + if (session == NULL) { + return NULL; + } // Prep and switch interpreters, including apply the updates. - if (_PyXI_Enter(&session, interp, updates) < 0) { + if (_PyXI_Enter(session, interp, updates) < 0) { if (!PyErr_Occurred()) { - _PyXI_ApplyCapturedException(&session); + _PyXI_ApplyCapturedException(session); assert(PyErr_Occurred()); } else { - assert(!_PyXI_HasCapturedException(&session)); + assert(!_PyXI_HasCapturedException(session)); } + _PyXI_FreeSession(session); return NULL; } // Clean up and switch back. - _PyXI_Exit(&session); + _PyXI_Exit(session); + _PyXI_FreeSession(session); Py_RETURN_NONE; } @@ -948,122 +867,38 @@ PyDoc_STRVAR(set___main___attrs_doc, Bind the given attributes in the interpreter's __main__ module."); -static PyUnicodeObject * -convert_script_arg(PyObject *arg, const char *fname, const char *displayname, - const char *expected) -{ - PyUnicodeObject *str = NULL; - if (PyUnicode_CheckExact(arg)) { - str = (PyUnicodeObject *)Py_NewRef(arg); - } - else if (PyUnicode_Check(arg)) { - // XXX str = PyUnicode_FromObject(arg); - str = (PyUnicodeObject *)Py_NewRef(arg); - } - else { - _PyArg_BadArgument(fname, displayname, expected, arg); - return NULL; - } - - const char *err = check_code_str(str); - if (err != NULL) { - Py_DECREF(str); - PyErr_Format(PyExc_ValueError, - "%.200s(): bad script text (%s)", fname, err); - return NULL; - } - - return str; -} - -static PyCodeObject * -convert_code_arg(PyObject *arg, const char *fname, const char *displayname, - const char *expected) +static void +unwrap_not_shareable(PyThreadState *tstate) { - const char *kind = NULL; - PyCodeObject *code = NULL; - if (PyFunction_Check(arg)) { - if (PyFunction_GetClosure(arg) != NULL) { - PyErr_Format(PyExc_ValueError, - "%.200s(): closures not supported", fname); - return NULL; - } - code = (PyCodeObject *)PyFunction_GetCode(arg); - if (code == NULL) { - if (PyErr_Occurred()) { - // This chains. - PyErr_Format(PyExc_ValueError, - "%.200s(): bad func", fname); - } - else { - PyErr_Format(PyExc_ValueError, - "%.200s(): func.__code__ missing", fname); - } - return NULL; - } - Py_INCREF(code); - kind = "func"; + PyObject *exctype = _PyXIData_GetNotShareableErrorType(tstate); + if (!_PyErr_ExceptionMatches(tstate, exctype)) { + return; } - else if (PyCode_Check(arg)) { - code = (PyCodeObject *)Py_NewRef(arg); - kind = "code object"; + PyObject *exc = _PyErr_GetRaisedException(tstate); + PyObject *cause = PyException_GetCause(exc); + if (cause != NULL) { + Py_DECREF(exc); + exc = cause; } else { - _PyArg_BadArgument(fname, displayname, expected, arg); - return NULL; - } - - const char *err = check_code_object(code); - if (err != NULL) { - Py_DECREF(code); - PyErr_Format(PyExc_ValueError, - "%.200s(): bad %s (%s)", fname, kind, err); - return NULL; - } - - return code; -} - -static int -_interp_exec(PyObject *self, PyInterpreterState *interp, - PyObject *code_arg, PyObject *shared_arg, PyObject **p_excinfo) -{ - if (shared_arg != NULL && !PyDict_CheckExact(shared_arg)) { - PyErr_SetString(PyExc_TypeError, "expected 'shared' to be a dict"); - return -1; - } - - // Extract code. - Py_ssize_t codestrlen = -1; - PyObject *bytes_obj = NULL; - int flags = 0; - const char *codestr = get_code_str(code_arg, - &codestrlen, &bytes_obj, &flags); - if (codestr == NULL) { - return -1; - } - - // Run the code in the interpreter. - int res = _run_in_interpreter(interp, codestr, codestrlen, - shared_arg, flags, p_excinfo); - Py_XDECREF(bytes_obj); - if (res < 0) { - return -1; + assert(PyException_GetContext(exc) == NULL); } - - return 0; + _PyErr_SetRaisedException(tstate, exc); } static PyObject * interp_exec(PyObject *self, PyObject *args, PyObject *kwds) { +#define FUNCNAME MODULE_NAME_STR ".exec" + PyThreadState *tstate = _PyThreadState_GET(); static char *kwlist[] = {"id", "code", "shared", "restrict", NULL}; PyObject *id, *code; PyObject *shared = NULL; int restricted = 0; if (!PyArg_ParseTupleAndKeywords(args, kwds, - "OO|O$p:" MODULE_NAME_STR ".exec", kwlist, - &id, &code, &shared, &restricted)) + "OO|O!$p:" FUNCNAME, kwlist, + &id, &code, &PyDict_Type, &shared, + &restricted)) { return NULL; } @@ -1075,27 +910,23 @@ interp_exec(PyObject *self, PyObject *args, PyObject *kwds) return NULL; } - const char *expected = "a string, a function, or a code object"; - if (PyUnicode_Check(code)) { - code = (PyObject *)convert_script_arg(code, MODULE_NAME_STR ".exec", - "argument 2", expected); - } - else { - code = (PyObject *)convert_code_arg(code, MODULE_NAME_STR ".exec", - "argument 2", expected); - } - if (code == NULL) { + // We don't need the script to be "pure", which means it can use + // global variables. They will be resolved against __main__. + _PyXIData_t xidata = {0}; + if (_PyCode_GetScriptXIData(tstate, code, &xidata) < 0) { + unwrap_not_shareable(tstate); return NULL; } PyObject *excinfo = NULL; - int res = _interp_exec(self, interp, code, shared, &excinfo); - Py_DECREF(code); + int res = _exec_in_interpreter(tstate, interp, &xidata, shared, &excinfo); + _PyXIData_Release(&xidata); if (res < 0) { assert((excinfo == NULL) != (PyErr_Occurred() == NULL)); return excinfo; } Py_RETURN_NONE; +#undef FUNCNAME } PyDoc_STRVAR(exec_doc, @@ -1118,13 +949,16 @@ is ignored, including its __globals__ dict."); static PyObject * interp_run_string(PyObject *self, PyObject *args, PyObject *kwds) { +#define FUNCNAME MODULE_NAME_STR ".run_string" + PyThreadState *tstate = _PyThreadState_GET(); static char *kwlist[] = {"id", "script", "shared", "restrict", NULL}; PyObject *id, *script; PyObject *shared = NULL; int restricted = 0; if (!PyArg_ParseTupleAndKeywords(args, kwds, - "OU|O$p:" MODULE_NAME_STR ".run_string", - kwlist, &id, &script, &shared, &restricted)) + "OU|O!$p:" FUNCNAME, kwlist, + &id, &script, &PyDict_Type, &shared, + &restricted)) { return NULL; } @@ -1136,20 +970,26 @@ interp_run_string(PyObject *self, PyObject *args, PyObject *kwds) return NULL; } - script = (PyObject *)convert_script_arg(script, MODULE_NAME_STR ".run_string", - "argument 2", "a string"); - if (script == NULL) { + if (PyFunction_Check(script) || PyCode_Check(script)) { + _PyArg_BadArgument(FUNCNAME, "argument 2", "a string", script); + return NULL; + } + + _PyXIData_t xidata = {0}; + if (_PyCode_GetScriptXIData(tstate, script, &xidata) < 0) { + unwrap_not_shareable(tstate); return NULL; } PyObject *excinfo = NULL; - int res = _interp_exec(self, interp, script, shared, &excinfo); - Py_DECREF(script); + int res = _exec_in_interpreter(tstate, interp, &xidata, shared, &excinfo); + _PyXIData_Release(&xidata); if (res < 0) { assert((excinfo == NULL) != (PyErr_Occurred() == NULL)); return excinfo; } Py_RETURN_NONE; +#undef FUNCNAME } PyDoc_STRVAR(run_string_doc, @@ -1162,13 +1002,16 @@ Execute the provided string in the identified interpreter.\n\ static PyObject * interp_run_func(PyObject *self, PyObject *args, PyObject *kwds) { +#define FUNCNAME MODULE_NAME_STR ".run_func" + PyThreadState *tstate = _PyThreadState_GET(); static char *kwlist[] = {"id", "func", "shared", "restrict", NULL}; PyObject *id, *func; PyObject *shared = NULL; int restricted = 0; if (!PyArg_ParseTupleAndKeywords(args, kwds, - "OO|O$p:" MODULE_NAME_STR ".run_func", - kwlist, &id, &func, &shared, &restricted)) + "OO|O!$p:" FUNCNAME, kwlist, + &id, &func, &PyDict_Type, &shared, + &restricted)) { return NULL; } @@ -1180,21 +1023,35 @@ interp_run_func(PyObject *self, PyObject *args, PyObject *kwds) return NULL; } - PyCodeObject *code = convert_code_arg(func, MODULE_NAME_STR ".exec", - "argument 2", - "a function or a code object"); - if (code == NULL) { + // We don't worry about checking globals. They will be resolved + // against __main__. + PyObject *code; + if (PyFunction_Check(func)) { + code = PyFunction_GET_CODE(func); + } + else if (PyCode_Check(func)) { + code = func; + } + else { + _PyArg_BadArgument(FUNCNAME, "argument 2", "a function", func); + return NULL; + } + + _PyXIData_t xidata = {0}; + if (_PyCode_GetScriptXIData(tstate, code, &xidata) < 0) { + unwrap_not_shareable(tstate); return NULL; } PyObject *excinfo = NULL; - int res = _interp_exec(self, interp, (PyObject *)code, shared, &excinfo); - Py_DECREF(code); + int res = _exec_in_interpreter(tstate, interp, &xidata, shared, &excinfo); + _PyXIData_Release(&xidata); if (res < 0) { assert((excinfo == NULL) != (PyErr_Occurred() == NULL)); return excinfo; } Py_RETURN_NONE; +#undef FUNCNAME } PyDoc_STRVAR(run_func_doc, @@ -1209,6 +1066,8 @@ are not supported. Methods and other callables are not supported either.\n\ static PyObject * interp_call(PyObject *self, PyObject *args, PyObject *kwds) { +#define FUNCNAME MODULE_NAME_STR ".call" + PyThreadState *tstate = _PyThreadState_GET(); static char *kwlist[] = {"id", "callable", "args", "kwargs", "restrict", NULL}; PyObject *id, *callable; @@ -1216,7 +1075,7 @@ interp_call(PyObject *self, PyObject *args, PyObject *kwds) PyObject *kwargs_obj = NULL; int restricted = 0; if (!PyArg_ParseTupleAndKeywords(args, kwds, - "OO|OO$p:" MODULE_NAME_STR ".call", kwlist, + "OO|OO$p:" FUNCNAME, kwlist, &id, &callable, &args_obj, &kwargs_obj, &restricted)) { @@ -1231,28 +1090,29 @@ interp_call(PyObject *self, PyObject *args, PyObject *kwds) } if (args_obj != NULL) { - PyErr_SetString(PyExc_ValueError, "got unexpected args"); + _PyErr_SetString(tstate, PyExc_ValueError, "got unexpected args"); return NULL; } if (kwargs_obj != NULL) { - PyErr_SetString(PyExc_ValueError, "got unexpected kwargs"); + _PyErr_SetString(tstate, PyExc_ValueError, "got unexpected kwargs"); return NULL; } - PyObject *code = (PyObject *)convert_code_arg(callable, MODULE_NAME_STR ".call", - "argument 2", "a function"); - if (code == NULL) { + _PyXIData_t xidata = {0}; + if (_PyCode_GetPureScriptXIData(tstate, callable, &xidata) < 0) { + unwrap_not_shareable(tstate); return NULL; } PyObject *excinfo = NULL; - int res = _interp_exec(self, interp, code, NULL, &excinfo); - Py_DECREF(code); + int res = _exec_in_interpreter(tstate, interp, &xidata, NULL, &excinfo); + _PyXIData_Release(&xidata); if (res < 0) { assert((excinfo == NULL) != (PyErr_Occurred() == NULL)); return excinfo; } Py_RETURN_NONE; +#undef FUNCNAME } PyDoc_STRVAR(call_doc, diff --git a/Modules/_testinternalcapi.c b/Modules/_testinternalcapi.c index 92f744c5a5fc70..9659f9630c17c3 100644 --- a/Modules/_testinternalcapi.c +++ b/Modules/_testinternalcapi.c @@ -1974,6 +1974,14 @@ get_crossinterp_data(PyObject *self, PyObject *args, PyObject *kwargs) goto error; } } + else if (strcmp(mode, "fallback") == 0) { + xidata_fallback_t fallback = _PyXIDATA_FULL_FALLBACK; + if (_PyObject_GetXIDataWithFallback( + tstate, obj, fallback, xidata) != 0) + { + goto error; + } + } else if (strcmp(mode, "pickle") == 0) { if (_PyPickle_GetXIData(tstate, obj, xidata) != 0) { goto error; diff --git a/Python/crossinterp.c b/Python/crossinterp.c index 725d6009f84014..634d3d5d2b171e 100644 --- a/Python/crossinterp.c +++ b/Python/crossinterp.c @@ -210,16 +210,16 @@ _Py_CallInInterpreterAndRawFree(PyInterpreterState *interp, /* cross-interpreter data */ /**************************/ -/* registry of {type -> xidatafunc} */ +/* registry of {type -> _PyXIData_getdata_t} */ -/* For now we use a global registry of shareable classes. An - alternative would be to add a tp_* slot for a class's - xidatafunc. It would be simpler and more efficient. */ +/* For now we use a global registry of shareable classes. + An alternative would be to add a tp_* slot for a class's + _PyXIData_getdata_t. It would be simpler and more efficient. */ static void xid_lookup_init(_PyXIData_lookup_t *); static void xid_lookup_fini(_PyXIData_lookup_t *); struct _dlcontext; -static xidatafunc lookup_getdata(struct _dlcontext *, PyObject *); +static _PyXIData_getdata_t lookup_getdata(struct _dlcontext *, PyObject *); #include "crossinterp_data_lookup.h" @@ -343,7 +343,7 @@ _set_xid_lookup_failure(PyThreadState *tstate, PyObject *obj, const char *msg, set_notshareableerror(tstate, cause, 0, msg); } else { - msg = "%S does not support cross-interpreter data"; + msg = "%R does not support cross-interpreter data"; format_notshareableerror(tstate, cause, 0, msg, obj); } } @@ -356,8 +356,8 @@ _PyObject_CheckXIData(PyThreadState *tstate, PyObject *obj) if (get_lookup_context(tstate, &ctx) < 0) { return -1; } - xidatafunc getdata = lookup_getdata(&ctx, obj); - if (getdata == NULL) { + _PyXIData_getdata_t getdata = lookup_getdata(&ctx, obj); + if (getdata.basic == NULL && getdata.fallback == NULL) { if (!_PyErr_Occurred(tstate)) { _set_xid_lookup_failure(tstate, obj, NULL, NULL); } @@ -388,9 +388,9 @@ _check_xidata(PyThreadState *tstate, _PyXIData_t *xidata) return 0; } -int -_PyObject_GetXIData(PyThreadState *tstate, - PyObject *obj, _PyXIData_t *xidata) +static int +_get_xidata(PyThreadState *tstate, + PyObject *obj, xidata_fallback_t fallback, _PyXIData_t *xidata) { PyInterpreterState *interp = tstate->interp; @@ -398,6 +398,7 @@ _PyObject_GetXIData(PyThreadState *tstate, assert(xidata->obj == NULL); if (xidata->data != NULL || xidata->obj != NULL) { _PyErr_SetString(tstate, PyExc_ValueError, "xidata not cleared"); + return -1; } // Call the "getdata" func for the object. @@ -406,8 +407,8 @@ _PyObject_GetXIData(PyThreadState *tstate, return -1; } Py_INCREF(obj); - xidatafunc getdata = lookup_getdata(&ctx, obj); - if (getdata == NULL) { + _PyXIData_getdata_t getdata = lookup_getdata(&ctx, obj); + if (getdata.basic == NULL && getdata.fallback == NULL) { if (PyErr_Occurred()) { Py_DECREF(obj); return -1; @@ -419,7 +420,9 @@ _PyObject_GetXIData(PyThreadState *tstate, } return -1; } - int res = getdata(tstate, obj, xidata); + int res = getdata.basic != NULL + ? getdata.basic(tstate, obj, xidata) + : getdata.fallback(tstate, obj, fallback, xidata); Py_DECREF(obj); if (res != 0) { PyObject *cause = _PyErr_GetRaisedException(tstate); @@ -439,6 +442,51 @@ _PyObject_GetXIData(PyThreadState *tstate, return 0; } +int +_PyObject_GetXIData(PyThreadState *tstate, + PyObject *obj, _PyXIData_t *xidata) +{ + return _get_xidata(tstate, obj, _PyXIDATA_XIDATA_ONLY, xidata); +} + +int +_PyObject_GetXIDataWithFallback(PyThreadState *tstate, + PyObject *obj, xidata_fallback_t fallback, + _PyXIData_t *xidata) +{ + switch (fallback) { + case _PyXIDATA_XIDATA_ONLY: + return _get_xidata(tstate, obj, fallback, xidata); + case _PyXIDATA_FULL_FALLBACK: + if (_get_xidata(tstate, obj, fallback, xidata) == 0) { + return 0; + } + PyObject *exc = _PyErr_GetRaisedException(tstate); + if (PyFunction_Check(obj)) { + if (_PyFunction_GetXIData(tstate, obj, xidata) == 0) { + Py_DECREF(exc); + return 0; + } + _PyErr_Clear(tstate); + } + // We could try _PyMarshal_GetXIData() but we won't for now. + if (_PyPickle_GetXIData(tstate, obj, xidata) == 0) { + Py_DECREF(exc); + return 0; + } + // Raise the original exception. + _PyErr_SetRaisedException(tstate, exc); + return -1; + default: +#ifdef Py_DEBUG + Py_UNREACHABLE(); +#endif + _PyErr_SetString(tstate, PyExc_SystemError, + "unknown xidata fallback"); + return -1; + } +} + /* pickle C-API */ @@ -1796,6 +1844,7 @@ _sharednsitem_set_value(_PyXI_namespace_item *item, PyObject *value) return -1; } PyThreadState *tstate = PyThreadState_Get(); + // XXX Use _PyObject_GetXIDataWithFallback()? if (_PyObject_GetXIData(tstate, value, item->xidata) != 0) { PyMem_RawFree(item->xidata); item->xidata = NULL; @@ -1870,156 +1919,212 @@ _sharednsitem_apply(_PyXI_namespace_item *item, PyObject *ns, PyObject *dflt) return res; } -struct _sharedns { - Py_ssize_t len; - _PyXI_namespace_item *items; -}; -static _PyXI_namespace * -_sharedns_new(void) -{ - _PyXI_namespace *ns = PyMem_RawCalloc(sizeof(_PyXI_namespace), 1); - if (ns == NULL) { - PyErr_NoMemory(); - return NULL; - } - *ns = (_PyXI_namespace){ 0 }; - return ns; -} +typedef struct { + Py_ssize_t maxitems; + Py_ssize_t numnames; + Py_ssize_t numvalues; + _PyXI_namespace_item items[1]; +} _PyXI_namespace; +#ifndef NDEBUG static int -_sharedns_is_initialized(_PyXI_namespace *ns) +_sharedns_check_counts(_PyXI_namespace *ns) { - if (ns->len == 0) { - assert(ns->items == NULL); + if (ns->maxitems <= 0) { + return 0; + } + if (ns->numnames < 0) { + return 0; + } + if (ns->numnames > ns->maxitems) { + return 0; + } + if (ns->numvalues < 0) { + return 0; + } + if (ns->numvalues > ns->numnames) { return 0; } - - assert(ns->len > 0); - assert(ns->items != NULL); - assert(_sharednsitem_is_initialized(&ns->items[0])); - assert(ns->len == 1 - || _sharednsitem_is_initialized(&ns->items[ns->len - 1])); return 1; } -#define HAS_COMPLETE_DATA 1 -#define HAS_PARTIAL_DATA 2 - static int -_sharedns_has_xidata(_PyXI_namespace *ns, int64_t *p_interpid) +_sharedns_check_consistency(_PyXI_namespace *ns) { - // We expect _PyXI_namespace to always be initialized. - assert(_sharedns_is_initialized(ns)); - int res = 0; - _PyXI_namespace_item *item0 = &ns->items[0]; - if (!_sharednsitem_is_initialized(item0)) { + if (!_sharedns_check_counts(ns)) { return 0; } - int64_t interpid0 = -1; - if (!_sharednsitem_has_value(item0, &interpid0)) { - return 0; + + Py_ssize_t i = 0; + _PyXI_namespace_item *item; + if (ns->numvalues > 0) { + item = &ns->items[0]; + if (!_sharednsitem_is_initialized(item)) { + return 0; + } + int64_t interpid0 = -1; + if (!_sharednsitem_has_value(item, &interpid0)) { + return 0; + } + i += 1; + for (; i < ns->numvalues; i++) { + item = &ns->items[i]; + if (!_sharednsitem_is_initialized(item)) { + return 0; + } + int64_t interpid = -1; + if (!_sharednsitem_has_value(item, &interpid)) { + return 0; + } + if (interpid != interpid0) { + return 0; + } + } } - if (ns->len > 1) { - // At this point we know it is has at least partial data. - _PyXI_namespace_item *itemN = &ns->items[ns->len-1]; - if (!_sharednsitem_is_initialized(itemN)) { - res = HAS_PARTIAL_DATA; - goto finally; + for (; i < ns->numnames; i++) { + item = &ns->items[i]; + if (!_sharednsitem_is_initialized(item)) { + return 0; } - int64_t interpidN = -1; - if (!_sharednsitem_has_value(itemN, &interpidN)) { - res = HAS_PARTIAL_DATA; - goto finally; + if (_sharednsitem_has_value(item, NULL)) { + return 0; } - assert(interpidN == interpid0); } - res = HAS_COMPLETE_DATA; - *p_interpid = interpid0; - -finally: - return res; + for (; i < ns->maxitems; i++) { + item = &ns->items[i]; + if (_sharednsitem_is_initialized(item)) { + return 0; + } + if (_sharednsitem_has_value(item, NULL)) { + return 0; + } + } + return 1; } +#endif -static void -_sharedns_clear(_PyXI_namespace *ns) +static _PyXI_namespace * +_sharedns_alloc(Py_ssize_t maxitems) { - if (!_sharedns_is_initialized(ns)) { - return; + if (maxitems < 0) { + if (!PyErr_Occurred()) { + PyErr_BadInternalCall(); + } + return NULL; + } + else if (maxitems == 0) { + PyErr_SetString(PyExc_ValueError, "empty namespaces not allowed"); + return NULL; } - // If the cross-interpreter data were allocated as part of - // _PyXI_namespace_item (instead of dynamically), this is where - // we would need verify that we are clearing the items in the - // correct interpreter, to avoid a race with releasing the XI data - // via a pending call. See _sharedns_has_xidata(). - for (Py_ssize_t i=0; i < ns->len; i++) { - _sharednsitem_clear(&ns->items[i]); + // Check for overflow. + size_t fixedsize = sizeof(_PyXI_namespace) - sizeof(_PyXI_namespace_item); + if ((size_t)maxitems > + ((size_t)PY_SSIZE_T_MAX - fixedsize) / sizeof(_PyXI_namespace_item)) + { + PyErr_NoMemory(); + return NULL; } - PyMem_RawFree(ns->items); - ns->items = NULL; - ns->len = 0; + + // Allocate the value, including items. + size_t size = fixedsize + sizeof(_PyXI_namespace_item) * maxitems; + + _PyXI_namespace *ns = PyMem_RawCalloc(size, 1); + if (ns == NULL) { + PyErr_NoMemory(); + return NULL; + } + ns->maxitems = maxitems; + assert(_sharedns_check_consistency(ns)); + return ns; } static void _sharedns_free(_PyXI_namespace *ns) { - _sharedns_clear(ns); + // If we weren't always dynamically allocating the cross-interpreter + // data in each item then we would need to use a pending call + // to call _sharedns_free(), to avoid the race between freeing + // the shared namespace and releasing the XI data. + assert(_sharedns_check_counts(ns)); + Py_ssize_t i = 0; + _PyXI_namespace_item *item; + if (ns->numvalues > 0) { + // One or more items may have interpreter-specific data. +#ifndef NDEBUG + int64_t interpid = PyInterpreterState_GetID(PyInterpreterState_Get()); + int64_t interpid_i; +#endif + for (; i < ns->numvalues; i++) { + item = &ns->items[i]; + assert(_sharednsitem_is_initialized(item)); + // While we do want to ensure consistency across items, + // technically they don't need to match the current + // interpreter. However, we keep the constraint for + // simplicity, by giving _PyXI_FreeNamespace() the exclusive + // responsibility of dealing with the owning interpreter. + assert(_sharednsitem_has_value(item, &interpid_i)); + assert(interpid_i == interpid); + _sharednsitem_clear(item); + } + } + for (; i < ns->numnames; i++) { + item = &ns->items[i]; + assert(_sharednsitem_is_initialized(item)); + assert(!_sharednsitem_has_value(item, NULL)); + _sharednsitem_clear(item); + } +#ifndef NDEBUG + for (; i < ns->maxitems; i++) { + item = &ns->items[i]; + assert(!_sharednsitem_is_initialized(item)); + assert(!_sharednsitem_has_value(item, NULL)); + } +#endif + PyMem_RawFree(ns); } -static int -_sharedns_init(_PyXI_namespace *ns, PyObject *names) +static _PyXI_namespace * +_create_sharedns(PyObject *names) { - assert(!_sharedns_is_initialized(ns)); assert(names != NULL); - Py_ssize_t len = PyDict_CheckExact(names) + Py_ssize_t numnames = PyDict_CheckExact(names) ? PyDict_Size(names) : PySequence_Size(names); - if (len < 0) { - return -1; - } - if (len == 0) { - PyErr_SetString(PyExc_ValueError, "empty namespaces not allowed"); - return -1; - } - assert(len > 0); - // Allocate the items. - _PyXI_namespace_item *items = - PyMem_RawCalloc(sizeof(struct _sharednsitem), len); - if (items == NULL) { - PyErr_NoMemory(); - return -1; + _PyXI_namespace *ns = _sharedns_alloc(numnames); + if (ns == NULL) { + return NULL; } + _PyXI_namespace_item *items = ns->items; // Fill in the names. - Py_ssize_t i = -1; if (PyDict_CheckExact(names)) { + Py_ssize_t i = 0; Py_ssize_t pos = 0; - for (i=0; i < len; i++) { - PyObject *key; - if (!PyDict_Next(names, &pos, &key, NULL)) { - // This should not be possible. - assert(0); - goto error; - } - if (_sharednsitem_init(&items[i], key) < 0) { + PyObject *name; + while(PyDict_Next(names, &pos, &name, NULL)) { + if (_sharednsitem_init(&items[i], name) < 0) { goto error; } + ns->numnames += 1; + i += 1; } } else if (PySequence_Check(names)) { - for (i=0; i < len; i++) { - PyObject *key = PySequence_GetItem(names, i); - if (key == NULL) { + for (Py_ssize_t i = 0; i < numnames; i++) { + PyObject *name = PySequence_GetItem(names, i); + if (name == NULL) { goto error; } - int res = _sharednsitem_init(&items[i], key); - Py_DECREF(key); + int res = _sharednsitem_init(&items[i], name); + Py_DECREF(name); if (res < 0) { goto error; } + ns->numnames += 1; } } else { @@ -2027,140 +2132,79 @@ _sharedns_init(_PyXI_namespace *ns, PyObject *names) "non-sequence namespace not supported"); goto error; } - - ns->items = items; - ns->len = len; - assert(_sharedns_is_initialized(ns)); - return 0; + assert(ns->numnames == ns->maxitems); + return ns; error: - for (Py_ssize_t j=0; j < i; j++) { - _sharednsitem_clear(&items[j]); - } - PyMem_RawFree(items); - assert(!_sharedns_is_initialized(ns)); - return -1; -} - -void -_PyXI_FreeNamespace(_PyXI_namespace *ns) -{ - if (!_sharedns_is_initialized(ns)) { - return; - } - - int64_t interpid = -1; - if (!_sharedns_has_xidata(ns, &interpid)) { - _sharedns_free(ns); - return; - } - - if (interpid == PyInterpreterState_GetID(PyInterpreterState_Get())) { - _sharedns_free(ns); - } - else { - // If we weren't always dynamically allocating the cross-interpreter - // data in each item then we would need to using a pending call - // to call _sharedns_free(), to avoid the race between freeing - // the shared namespace and releasing the XI data. - _sharedns_free(ns); - } -} - -_PyXI_namespace * -_PyXI_NamespaceFromNames(PyObject *names) -{ - if (names == NULL || names == Py_None) { - return NULL; - } - - _PyXI_namespace *ns = _sharedns_new(); - if (ns == NULL) { - return NULL; - } - - if (_sharedns_init(ns, names) < 0) { - PyMem_RawFree(ns); - if (PySequence_Size(names) == 0) { - PyErr_Clear(); - } - return NULL; - } - - return ns; + _sharedns_free(ns); + return NULL; } -#ifndef NDEBUG -static int _session_is_active(_PyXI_session *); -#endif static void _propagate_not_shareable_error(_PyXI_session *); -int -_PyXI_FillNamespaceFromDict(_PyXI_namespace *ns, PyObject *nsobj, - _PyXI_session *session) -{ - // session must be entered already, if provided. - assert(session == NULL || _session_is_active(session)); - assert(_sharedns_is_initialized(ns)); - for (Py_ssize_t i=0; i < ns->len; i++) { - _PyXI_namespace_item *item = &ns->items[i]; - if (_sharednsitem_copy_from_ns(item, nsobj) < 0) { +static int +_fill_sharedns(_PyXI_namespace *ns, PyObject *nsobj, _PyXI_session *session) +{ + // All items are expected to be shareable. + assert(_sharedns_check_counts(ns)); + assert(ns->numnames == ns->maxitems); + assert(ns->numvalues == 0); + for (Py_ssize_t i=0; i < ns->maxitems; i++) { + if (_sharednsitem_copy_from_ns(&ns->items[i], nsobj) < 0) { _propagate_not_shareable_error(session); // Clear out the ones we set so far. for (Py_ssize_t j=0; j < i; j++) { _sharednsitem_clear_value(&ns->items[j]); + ns->numvalues -= 1; } return -1; } + ns->numvalues += 1; } return 0; } -// All items are expected to be shareable. -static _PyXI_namespace * -_PyXI_NamespaceFromDict(PyObject *nsobj, _PyXI_session *session) +static int +_sharedns_free_pending(void *data) { - // session must be entered already, if provided. - assert(session == NULL || _session_is_active(session)); - if (nsobj == NULL || nsobj == Py_None) { - return NULL; - } - if (!PyDict_CheckExact(nsobj)) { - PyErr_SetString(PyExc_TypeError, "expected a dict"); - return NULL; - } + _sharedns_free((_PyXI_namespace *)data); + return 0; +} - _PyXI_namespace *ns = _sharedns_new(); - if (ns == NULL) { - return NULL; +static void +_destroy_sharedns(_PyXI_namespace *ns) +{ + assert(_sharedns_check_counts(ns)); + assert(ns->numnames == ns->maxitems); + if (ns->numvalues == 0) { + _sharedns_free(ns); + return; } - if (_sharedns_init(ns, nsobj) < 0) { - if (PyDict_Size(nsobj) == 0) { - PyMem_RawFree(ns); - PyErr_Clear(); - return NULL; - } - goto error; + int64_t interpid0; + if (!_sharednsitem_has_value(&ns->items[0], &interpid0)) { + // This shouldn't have been possible. + // We can deal with it in _sharedns_free(). + _sharedns_free(ns); + return; } - - if (_PyXI_FillNamespaceFromDict(ns, nsobj, session) < 0) { - goto error; + PyInterpreterState *interp = _PyInterpreterState_LookUpID(interpid0); + if (interp == PyInterpreterState_Get()) { + _sharedns_free(ns); + return; } - return ns; - -error: - assert(PyErr_Occurred() - || (session != NULL && session->error_override != NULL)); - _sharedns_free(ns); - return NULL; + // One or more items may have interpreter-specific data. + // Currently the xidata for each value is dynamically allocated, + // so technically we don't need to worry about that. + // However, explicitly adding a pending call here is simpler. + (void)_Py_CallInInterpreter(interp, _sharedns_free_pending, ns); } -int -_PyXI_ApplyNamespace(_PyXI_namespace *ns, PyObject *nsobj, PyObject *dflt) +static int +_apply_sharedns(_PyXI_namespace *ns, PyObject *nsobj, PyObject *dflt) { - for (Py_ssize_t i=0; i < ns->len; i++) { + for (Py_ssize_t i=0; i < ns->maxitems; i++) { if (_sharednsitem_apply(&ns->items[i], nsobj, dflt) != 0) { return -1; } @@ -2169,9 +2213,79 @@ _PyXI_ApplyNamespace(_PyXI_namespace *ns, PyObject *nsobj, PyObject *dflt) } -/**********************/ -/* high-level helpers */ -/**********************/ +/*********************************/ +/* switched-interpreter sessions */ +/*********************************/ + +struct xi_session { +#define SESSION_UNUSED 0 +#define SESSION_ACTIVE 1 + int status; + int switched; + + // Once a session has been entered, this is the tstate that was + // current before the session. If it is different from cur_tstate + // then we must have switched interpreters. Either way, this will + // be the current tstate once we exit the session. + PyThreadState *prev_tstate; + // Once a session has been entered, this is the current tstate. + // It must be current when the session exits. + PyThreadState *init_tstate; + // This is true if init_tstate needs cleanup during exit. + int own_init_tstate; + + // This is true if, while entering the session, init_thread took + // "ownership" of the interpreter's __main__ module. This means + // it is the only thread that is allowed to run code there. + // (Caveat: for now, users may still run exec() against the + // __main__ module's dict, though that isn't advisable.) + int running; + // This is a cached reference to the __dict__ of the entered + // interpreter's __main__ module. It is looked up when at the + // beginning of the session as a convenience. + PyObject *main_ns; + + // This is set if the interpreter is entered and raised an exception + // that needs to be handled in some special way during exit. + _PyXI_errcode *error_override; + // This is set if exit captured an exception to propagate. + _PyXI_error *error; + + // -- pre-allocated memory -- + _PyXI_error _error; + _PyXI_errcode _error_override; +}; + + +_PyXI_session * +_PyXI_NewSession(void) +{ + _PyXI_session *session = PyMem_RawCalloc(1, sizeof(_PyXI_session)); + if (session == NULL) { + PyErr_NoMemory(); + return NULL; + } + return session; +} + +void +_PyXI_FreeSession(_PyXI_session *session) +{ + assert(session->status == SESSION_UNUSED); + PyMem_RawFree(session); +} + + +static inline int +_session_is_active(_PyXI_session *session) +{ + return session->status == SESSION_ACTIVE; +} + +static int _ensure_main_ns(_PyXI_session *); +static inline void _session_set_error(_PyXI_session *, _PyXI_errcode); +static void _capture_current_exception(_PyXI_session *); + /* enter/exit a cross-interpreter session */ @@ -2179,6 +2293,7 @@ static void _enter_session(_PyXI_session *session, PyInterpreterState *interp) { // Set here and cleared in _exit_session(). + assert(session->status == SESSION_UNUSED); assert(!session->own_init_tstate); assert(session->init_tstate == NULL); assert(session->prev_tstate == NULL); @@ -2193,15 +2308,22 @@ _enter_session(_PyXI_session *session, PyInterpreterState *interp) // Switch to interpreter. PyThreadState *tstate = PyThreadState_Get(); PyThreadState *prev = tstate; - if (interp != tstate->interp) { + int same_interp = (interp == tstate->interp); + if (!same_interp) { tstate = _PyThreadState_NewBound(interp, _PyThreadState_WHENCE_EXEC); // XXX Possible GILState issues? - session->prev_tstate = PyThreadState_Swap(tstate); - assert(session->prev_tstate == prev); - session->own_init_tstate = 1; + PyThreadState *swapped = PyThreadState_Swap(tstate); + assert(swapped == prev); + (void)swapped; } - session->init_tstate = tstate; - session->prev_tstate = prev; + + *session = (_PyXI_session){ + .status = SESSION_ACTIVE, + .switched = !same_interp, + .init_tstate = tstate, + .prev_tstate = prev, + .own_init_tstate = !same_interp, + }; } static void @@ -2212,9 +2334,7 @@ _exit_session(_PyXI_session *session) assert(PyThreadState_Get() == tstate); // Release any of the entered interpreters resources. - if (session->main_ns != NULL) { - Py_CLEAR(session->main_ns); - } + Py_CLEAR(session->main_ns); // Ensure this thread no longer owns __main__. if (session->running) { @@ -2235,17 +2355,15 @@ _exit_session(_PyXI_session *session) else { assert(!session->own_init_tstate); } - session->prev_tstate = NULL; - session->init_tstate = NULL; -} -#ifndef NDEBUG -static int -_session_is_active(_PyXI_session *session) -{ - return (session->init_tstate != NULL); + // For now the error data persists past the exit. + *session = (_PyXI_session){ + .error_override = session->error_override, + .error = session->error, + ._error = session->_error, + ._error_override = session->_error_override, + }; } -#endif static void _propagate_not_shareable_error(_PyXI_session *session) @@ -2262,11 +2380,102 @@ _propagate_not_shareable_error(_PyXI_session *session) } if (PyErr_ExceptionMatches(exctype)) { // We want to propagate the exception directly. - session->_error_override = _PyXI_ERR_NOT_SHAREABLE; - session->error_override = &session->_error_override; + _session_set_error(session, _PyXI_ERR_NOT_SHAREABLE); + } +} + +PyObject * +_PyXI_ApplyCapturedException(_PyXI_session *session) +{ + assert(!PyErr_Occurred()); + assert(session->error != NULL); + PyObject *res = _PyXI_ApplyError(session->error); + assert((res == NULL) != (PyErr_Occurred() == NULL)); + session->error = NULL; + return res; +} + +int +_PyXI_HasCapturedException(_PyXI_session *session) +{ + return session->error != NULL; +} + +int +_PyXI_Enter(_PyXI_session *session, + PyInterpreterState *interp, PyObject *nsupdates) +{ + // Convert the attrs for cross-interpreter use. + _PyXI_namespace *sharedns = NULL; + if (nsupdates != NULL) { + Py_ssize_t len = PyDict_Size(nsupdates); + if (len < 0) { + return -1; + } + if (len > 0) { + sharedns = _create_sharedns(nsupdates); + if (sharedns == NULL) { + return -1; + } + if (_fill_sharedns(sharedns, nsupdates, NULL) < 0) { + assert(session->error == NULL); + _destroy_sharedns(sharedns); + return -1; + } + } + } + + // Switch to the requested interpreter (if necessary). + _enter_session(session, interp); + _PyXI_errcode errcode = _PyXI_ERR_UNCAUGHT_EXCEPTION; + + // Ensure this thread owns __main__. + if (_PyInterpreterState_SetRunningMain(interp) < 0) { + // In the case where we didn't switch interpreters, it would + // be more efficient to leave the exception in place and return + // immediately. However, life is simpler if we don't. + errcode = _PyXI_ERR_ALREADY_RUNNING; + goto error; + } + session->running = 1; + + // Apply the cross-interpreter data. + if (sharedns != NULL) { + if (_ensure_main_ns(session) < 0) { + errcode = _PyXI_ERR_MAIN_NS_FAILURE; + goto error; + } + if (_apply_sharedns(sharedns, session->main_ns, NULL) < 0) { + errcode = _PyXI_ERR_APPLY_NS_FAILURE; + goto error; + } + _destroy_sharedns(sharedns); + } + + errcode = _PyXI_ERR_NO_ERROR; + assert(!PyErr_Occurred()); + return 0; + +error: + // We want to propagate all exceptions here directly (best effort). + _session_set_error(session, errcode); + _exit_session(session); + if (sharedns != NULL) { + _destroy_sharedns(sharedns); } + return -1; +} + +void +_PyXI_Exit(_PyXI_session *session) +{ + _capture_current_exception(session); + _exit_session(session); } + +/* in an active cross-interpreter session */ + static void _capture_current_exception(_PyXI_session *session) { @@ -2328,100 +2537,55 @@ _capture_current_exception(_PyXI_session *session) // Finished! assert(!PyErr_Occurred()); - session->error = err; + session->error = err; } -PyObject * -_PyXI_ApplyCapturedException(_PyXI_session *session) -{ - assert(!PyErr_Occurred()); - assert(session->error != NULL); - PyObject *res = _PyXI_ApplyError(session->error); - assert((res == NULL) != (PyErr_Occurred() == NULL)); - session->error = NULL; - return res; -} - -int -_PyXI_HasCapturedException(_PyXI_session *session) +static inline void +_session_set_error(_PyXI_session *session, _PyXI_errcode errcode) { - return session->error != NULL; + assert(_session_is_active(session)); + assert(PyErr_Occurred()); + if (errcode != _PyXI_ERR_UNCAUGHT_EXCEPTION) { + session->_error_override = errcode; + session->error_override = &session->_error_override; + } + _capture_current_exception(session); } -int -_PyXI_Enter(_PyXI_session *session, - PyInterpreterState *interp, PyObject *nsupdates) +static int +_ensure_main_ns(_PyXI_session *session) { - // Convert the attrs for cross-interpreter use. - _PyXI_namespace *sharedns = NULL; - if (nsupdates != NULL) { - sharedns = _PyXI_NamespaceFromDict(nsupdates, NULL); - if (sharedns == NULL && PyErr_Occurred()) { - assert(session->error == NULL); - return -1; - } - } - - // Switch to the requested interpreter (if necessary). - _enter_session(session, interp); - PyThreadState *session_tstate = session->init_tstate; - _PyXI_errcode errcode = _PyXI_ERR_UNCAUGHT_EXCEPTION; - - // Ensure this thread owns __main__. - if (_PyInterpreterState_SetRunningMain(interp) < 0) { - // In the case where we didn't switch interpreters, it would - // be more efficient to leave the exception in place and return - // immediately. However, life is simpler if we don't. - errcode = _PyXI_ERR_ALREADY_RUNNING; - goto error; + assert(_session_is_active(session)); + if (session->main_ns != NULL) { + return 0; } - session->running = 1; - // Cache __main__.__dict__. - PyObject *main_mod = _Py_GetMainModule(session_tstate); + PyObject *main_mod = _Py_GetMainModule(session->init_tstate); if (_Py_CheckMainModule(main_mod) < 0) { - errcode = _PyXI_ERR_MAIN_NS_FAILURE; - goto error; + return -1; } PyObject *ns = PyModule_GetDict(main_mod); // borrowed Py_DECREF(main_mod); if (ns == NULL) { - errcode = _PyXI_ERR_MAIN_NS_FAILURE; - goto error; + return -1; } session->main_ns = Py_NewRef(ns); - - // Apply the cross-interpreter data. - if (sharedns != NULL) { - if (_PyXI_ApplyNamespace(sharedns, ns, NULL) < 0) { - errcode = _PyXI_ERR_APPLY_NS_FAILURE; - goto error; - } - _PyXI_FreeNamespace(sharedns); - } - - errcode = _PyXI_ERR_NO_ERROR; - assert(!PyErr_Occurred()); return 0; - -error: - assert(PyErr_Occurred()); - // We want to propagate all exceptions here directly (best effort). - assert(errcode != _PyXI_ERR_UNCAUGHT_EXCEPTION); - session->error_override = &errcode; - _capture_current_exception(session); - _exit_session(session); - if (sharedns != NULL) { - _PyXI_FreeNamespace(sharedns); - } - return -1; } -void -_PyXI_Exit(_PyXI_session *session) +PyObject * +_PyXI_GetMainNamespace(_PyXI_session *session) { - _capture_current_exception(session); - _exit_session(session); + if (!_session_is_active(session)) { + PyErr_SetString(PyExc_RuntimeError, "session not active"); + return NULL; + } + if (_ensure_main_ns(session) < 0) { + _session_set_error(session, _PyXI_ERR_MAIN_NS_FAILURE); + _capture_current_exception(session); + return NULL; + } + return session->main_ns; } diff --git a/Python/crossinterp_data_lookup.h b/Python/crossinterp_data_lookup.h index d69927dbcd387f..ba43e9d84084da 100644 --- a/Python/crossinterp_data_lookup.h +++ b/Python/crossinterp_data_lookup.h @@ -12,7 +12,8 @@ typedef _PyXIData_regitem_t dlregitem_t; // forward static void _xidregistry_init(dlregistry_t *); static void _xidregistry_fini(dlregistry_t *); -static xidatafunc _lookup_getdata_from_registry(dlcontext_t *, PyObject *); +static _PyXIData_getdata_t _lookup_getdata_from_registry( + dlcontext_t *, PyObject *); /* used in crossinterp.c */ @@ -49,7 +50,7 @@ get_lookup_context(PyThreadState *tstate, dlcontext_t *res) return 0; } -static xidatafunc +static _PyXIData_getdata_t lookup_getdata(dlcontext_t *ctx, PyObject *obj) { /* Cross-interpreter objects are looked up by exact match on the class. @@ -88,24 +89,24 @@ _PyXIData_FormatNotShareableError(PyThreadState *tstate, } -xidatafunc +_PyXIData_getdata_t _PyXIData_Lookup(PyThreadState *tstate, PyObject *obj) { dlcontext_t ctx; if (get_lookup_context(tstate, &ctx) < 0) { - return NULL; + return (_PyXIData_getdata_t){0}; } return lookup_getdata(&ctx, obj); } /***********************************************/ -/* a registry of {type -> xidatafunc} */ +/* a registry of {type -> _PyXIData_getdata_t} */ /***********************************************/ -/* For now we use a global registry of shareable classes. An - alternative would be to add a tp_* slot for a class's - xidatafunc. It would be simpler and more efficient. */ +/* For now we use a global registry of shareable classes. + An alternative would be to add a tp_* slot for a class's + _PyXIData_getdata_t. It would be simpler and more efficient. */ /* registry lifecycle */ @@ -200,7 +201,7 @@ _xidregistry_find_type(dlregistry_t *xidregistry, PyTypeObject *cls) return NULL; } -static xidatafunc +static _PyXIData_getdata_t _lookup_getdata_from_registry(dlcontext_t *ctx, PyObject *obj) { PyTypeObject *cls = Py_TYPE(obj); @@ -209,10 +210,12 @@ _lookup_getdata_from_registry(dlcontext_t *ctx, PyObject *obj) _xidregistry_lock(xidregistry); dlregitem_t *matched = _xidregistry_find_type(xidregistry, cls); - xidatafunc func = matched != NULL ? matched->getdata : NULL; + _PyXIData_getdata_t getdata = matched != NULL + ? matched->getdata + : (_PyXIData_getdata_t){0}; _xidregistry_unlock(xidregistry); - return func; + return getdata; } @@ -220,12 +223,13 @@ _lookup_getdata_from_registry(dlcontext_t *ctx, PyObject *obj) static int _xidregistry_add_type(dlregistry_t *xidregistry, - PyTypeObject *cls, xidatafunc getdata) + PyTypeObject *cls, _PyXIData_getdata_t getdata) { dlregitem_t *newhead = PyMem_RawMalloc(sizeof(dlregitem_t)); if (newhead == NULL) { return -1; } + assert((getdata.basic == NULL) != (getdata.fallback == NULL)); *newhead = (dlregitem_t){ // We do not keep a reference, to avoid keeping the class alive. .cls = cls, @@ -283,13 +287,13 @@ _xidregistry_clear(dlregistry_t *xidregistry) int _PyXIData_RegisterClass(PyThreadState *tstate, - PyTypeObject *cls, xidatafunc getdata) + PyTypeObject *cls, _PyXIData_getdata_t getdata) { if (!PyType_Check(cls)) { PyErr_Format(PyExc_ValueError, "only classes may be registered"); return -1; } - if (getdata == NULL) { + if (getdata.basic == NULL && getdata.fallback == NULL) { PyErr_Format(PyExc_ValueError, "missing 'getdata' func"); return -1; } @@ -304,7 +308,8 @@ _PyXIData_RegisterClass(PyThreadState *tstate, dlregitem_t *matched = _xidregistry_find_type(xidregistry, cls); if (matched != NULL) { - assert(matched->getdata == getdata); + assert(matched->getdata.basic == getdata.basic); + assert(matched->getdata.fallback == getdata.fallback); matched->refcount += 1; goto finally; } @@ -608,7 +613,8 @@ _tuple_shared_free(void* data) } static int -_tuple_shared(PyThreadState *tstate, PyObject *obj, _PyXIData_t *xidata) +_tuple_shared(PyThreadState *tstate, PyObject *obj, xidata_fallback_t fallback, + _PyXIData_t *xidata) { Py_ssize_t len = PyTuple_GET_SIZE(obj); if (len < 0) { @@ -636,7 +642,8 @@ _tuple_shared(PyThreadState *tstate, PyObject *obj, _PyXIData_t *xidata) int res = -1; if (!_Py_EnterRecursiveCallTstate(tstate, " while sharing a tuple")) { - res = _PyObject_GetXIData(tstate, item, xidata_i); + res = _PyObject_GetXIDataWithFallback( + tstate, item, fallback, xidata_i); _Py_LeaveRecursiveCallTstate(tstate); } if (res < 0) { @@ -737,40 +744,48 @@ _PyFunction_GetXIData(PyThreadState *tstate, PyObject *func, static void _register_builtins_for_crossinterpreter_data(dlregistry_t *xidregistry) { +#define REGISTER(TYPE, GETDATA) \ + _xidregistry_add_type(xidregistry, (PyTypeObject *)TYPE, \ + ((_PyXIData_getdata_t){.basic=(GETDATA)})) +#define REGISTER_FALLBACK(TYPE, GETDATA) \ + _xidregistry_add_type(xidregistry, (PyTypeObject *)TYPE, \ + ((_PyXIData_getdata_t){.fallback=(GETDATA)})) // None - if (_xidregistry_add_type(xidregistry, (PyTypeObject *)PyObject_Type(Py_None), _none_shared) != 0) { + if (REGISTER(Py_TYPE(Py_None), _none_shared) != 0) { Py_FatalError("could not register None for cross-interpreter sharing"); } // int - if (_xidregistry_add_type(xidregistry, &PyLong_Type, _long_shared) != 0) { + if (REGISTER(&PyLong_Type, _long_shared) != 0) { Py_FatalError("could not register int for cross-interpreter sharing"); } // bytes - if (_xidregistry_add_type(xidregistry, &PyBytes_Type, _PyBytes_GetXIData) != 0) { + if (REGISTER(&PyBytes_Type, _PyBytes_GetXIData) != 0) { Py_FatalError("could not register bytes for cross-interpreter sharing"); } // str - if (_xidregistry_add_type(xidregistry, &PyUnicode_Type, _str_shared) != 0) { + if (REGISTER(&PyUnicode_Type, _str_shared) != 0) { Py_FatalError("could not register str for cross-interpreter sharing"); } // bool - if (_xidregistry_add_type(xidregistry, &PyBool_Type, _bool_shared) != 0) { + if (REGISTER(&PyBool_Type, _bool_shared) != 0) { Py_FatalError("could not register bool for cross-interpreter sharing"); } // float - if (_xidregistry_add_type(xidregistry, &PyFloat_Type, _float_shared) != 0) { + if (REGISTER(&PyFloat_Type, _float_shared) != 0) { Py_FatalError("could not register float for cross-interpreter sharing"); } // tuple - if (_xidregistry_add_type(xidregistry, &PyTuple_Type, _tuple_shared) != 0) { + if (REGISTER_FALLBACK(&PyTuple_Type, _tuple_shared) != 0) { Py_FatalError("could not register tuple for cross-interpreter sharing"); } // For now, we do not register PyCode_Type or PyFunction_Type. +#undef REGISTER +#undef REGISTER_FALLBACK }