Skip to content

gh-84570: Add Timeouts to SendChannel.send() and RecvChannel.recv() #110567

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Oct 17, 2023
6 changes: 6 additions & 0 deletions Include/internal/pycore_pythread.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@ extern int _PyThread_at_fork_reinit(PyThread_type_lock *lock);
// unset: -1 seconds, in nanoseconds
#define PyThread_UNSET_TIMEOUT ((_PyTime_t)(-1 * 1000 * 1000 * 1000))

// Exported for the _xxinterpchannels module.
PyAPI_FUNC(int) PyThread_ParseTimeoutArg(
PyObject *arg,
int blocking,
PY_TIMEOUT_T *timeout);

/* Helper to acquire an interruptible lock with a timeout. If the lock acquire
* is interrupted, signal handlers are run, and if they raise an exception,
* PY_LOCK_INTR is returned. Otherwise, PY_LOCK_ACQUIRED or PY_LOCK_FAILURE
Expand Down
20 changes: 15 additions & 5 deletions Lib/test/support/interpreters.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,15 +170,25 @@ class RecvChannel(_ChannelEnd):

_end = 'recv'

def recv(self, *, _sentinel=object(), _delay=10 / 1000): # 10 milliseconds
def recv(self, timeout=None, *,
_sentinel=object(),
_delay=10 / 1000, # 10 milliseconds
):
"""Return the next object from the channel.

This blocks until an object has been sent, if none have been
sent already.
"""
if timeout is not None:
timeout = int(timeout)
if timeout < 0:
raise ValueError(f'timeout value must be non-negative')
end = time.time() + timeout
obj = _channels.recv(self._id, _sentinel)
while obj is _sentinel:
time.sleep(_delay)
if timeout is not None and time.time() >= end:
raise TimeoutError
obj = _channels.recv(self._id, _sentinel)
return obj

Expand All @@ -203,12 +213,12 @@ class SendChannel(_ChannelEnd):

_end = 'send'

def send(self, obj):
def send(self, obj, timeout=None):
"""Send the object (i.e. its data) to the channel's receiving end.

This blocks until the object is received.
"""
_channels.send(self._id, obj, blocking=True)
_channels.send(self._id, obj, timeout=timeout, blocking=True)

def send_nowait(self, obj):
"""Send the object to the channel's receiving end.
Expand All @@ -221,12 +231,12 @@ def send_nowait(self, obj):
# See bpo-32604 and gh-19829.
return _channels.send(self._id, obj, blocking=False)

def send_buffer(self, obj):
def send_buffer(self, obj, timeout=None):
"""Send the object's buffer to the channel's receiving end.

This blocks until the object is received.
"""
_channels.send_buffer(self._id, obj, blocking=True)
_channels.send_buffer(self._id, obj, timeout=timeout, blocking=True)

def send_buffer_nowait(self, obj):
"""Send the object's buffer to the channel's receiving end.
Expand Down
128 changes: 108 additions & 20 deletions Lib/test/test__xxinterpchannels.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,22 +864,97 @@ def f():

self.assertEqual(received, obj)

def test_send_timeout(self):
obj = b'spam'

with self.subTest('non-blocking with timeout'):
cid = channels.create()
with self.assertRaises(ValueError):
channels.send(cid, obj, blocking=False, timeout=0.1)

with self.subTest('timeout hit'):
cid = channels.create()
with self.assertRaises(TimeoutError):
channels.send(cid, obj, blocking=True, timeout=0.1)
with self.assertRaises(channels.ChannelEmptyError):
received = channels.recv(cid)
print(repr(received))

with self.subTest('timeout not hit'):
cid = channels.create()
def f():
recv_wait(cid)
t = threading.Thread(target=f)
t.start()
channels.send(cid, obj, blocking=True, timeout=10)
t.join()

def test_send_buffer_timeout(self):
try:
self._has_run_once_timeout
except AttributeError:
# At the moment, this test leaks a few references.
# It looks like the leak originates with the addition
# of _channels.send_buffer() (gh-110246), whereas the
# tests were added afterward. We want this test even
# if the refleak isn't fixed yet, so we skip here.
raise unittest.SkipTest('temporarily skipped due to refleaks')
else:
self._has_run_once_timeout = True

obj = bytearray(b'spam')

with self.subTest('non-blocking with timeout'):
cid = channels.create()
with self.assertRaises(ValueError):
channels.send_buffer(cid, obj, blocking=False, timeout=0.1)

with self.subTest('timeout hit'):
cid = channels.create()
with self.assertRaises(TimeoutError):
channels.send_buffer(cid, obj, blocking=True, timeout=0.1)
with self.assertRaises(channels.ChannelEmptyError):
received = channels.recv(cid)
print(repr(received))

with self.subTest('timeout not hit'):
cid = channels.create()
def f():
recv_wait(cid)
t = threading.Thread(target=f)
t.start()
channels.send_buffer(cid, obj, blocking=True, timeout=10)
t.join()

def test_send_closed_while_waiting(self):
obj = b'spam'
wait = self.build_send_waiter(obj)
cid = channels.create()
def f():
wait()
channels.close(cid, force=True)
t = threading.Thread(target=f)
t.start()
with self.assertRaises(channels.ChannelClosedError):
channels.send(cid, obj, blocking=True)
t.join()

with self.subTest('without timeout'):
cid = channels.create()
def f():
wait()
channels.close(cid, force=True)
t = threading.Thread(target=f)
t.start()
with self.assertRaises(channels.ChannelClosedError):
channels.send(cid, obj, blocking=True)
t.join()

with self.subTest('with timeout'):
cid = channels.create()
def f():
wait()
channels.close(cid, force=True)
t = threading.Thread(target=f)
t.start()
with self.assertRaises(channels.ChannelClosedError):
channels.send(cid, obj, blocking=True, timeout=30)
t.join()

def test_send_buffer_closed_while_waiting(self):
try:
self._has_run_once
self._has_run_once_closed
except AttributeError:
# At the moment, this test leaks a few references.
# It looks like the leak originates with the addition
Expand All @@ -888,19 +963,32 @@ def test_send_buffer_closed_while_waiting(self):
# if the refleak isn't fixed yet, so we skip here.
raise unittest.SkipTest('temporarily skipped due to refleaks')
else:
self._has_run_once = True
self._has_run_once_closed = True

obj = bytearray(b'spam')
wait = self.build_send_waiter(obj, buffer=True)
cid = channels.create()
def f():
wait()
channels.close(cid, force=True)
t = threading.Thread(target=f)
t.start()
with self.assertRaises(channels.ChannelClosedError):
channels.send_buffer(cid, obj, blocking=True)
t.join()

with self.subTest('without timeout'):
cid = channels.create()
def f():
wait()
channels.close(cid, force=True)
t = threading.Thread(target=f)
t.start()
with self.assertRaises(channels.ChannelClosedError):
channels.send_buffer(cid, obj, blocking=True)
t.join()

with self.subTest('with timeout'):
cid = channels.create()
def f():
wait()
channels.close(cid, force=True)
t = threading.Thread(target=f)
t.start()
with self.assertRaises(channels.ChannelClosedError):
channels.send_buffer(cid, obj, blocking=True, timeout=30)
t.join()

#-------------------
# close
Expand Down
5 changes: 5 additions & 0 deletions Lib/test/test_interpreters.py
Original file line number Diff line number Diff line change
Expand Up @@ -1022,6 +1022,11 @@ def test_send_recv_nowait_different_interpreters(self):
self.assertEqual(obj2, b'eggs')
self.assertNotEqual(id(obj2), int(out))

def test_recv_timeout(self):
r, _ = interpreters.create_channel()
with self.assertRaises(TimeoutError):
r.recv(timeout=1)

def test_recv_channel_does_not_exist(self):
ch = interpreters.RecvChannel(1_000_000)
with self.assertRaises(interpreters.ChannelNotFoundError):
Expand Down
2 changes: 2 additions & 0 deletions Modules/_queuemodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,8 @@ _queue_SimpleQueue_get_impl(simplequeueobject *self, PyTypeObject *cls,
PY_TIMEOUT_T microseconds;
PyThreadState *tstate = PyThreadState_Get();

// XXX Use PyThread_ParseTimeoutArg().

if (block == 0) {
/* Non-blocking */
microseconds = 0;
Expand Down
11 changes: 6 additions & 5 deletions Modules/_threadmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,15 @@ lock_acquire_parse_args(PyObject *args, PyObject *kwds,
char *kwlist[] = {"blocking", "timeout", NULL};
int blocking = 1;
PyObject *timeout_obj = NULL;
const _PyTime_t unset_timeout = _PyTime_FromSeconds(-1);

*timeout = unset_timeout ;

if (!PyArg_ParseTupleAndKeywords(args, kwds, "|pO:acquire", kwlist,
&blocking, &timeout_obj))
return -1;

// XXX Use PyThread_ParseTimeoutArg().

const _PyTime_t unset_timeout = _PyTime_FromSeconds(-1);
*timeout = unset_timeout;

if (timeout_obj
&& _PyTime_FromSecondsObject(timeout,
timeout_obj, _PyTime_ROUND_TIMEOUT) < 0)
Expand All @@ -108,7 +109,7 @@ lock_acquire_parse_args(PyObject *args, PyObject *kwds,
}
if (*timeout < 0 && *timeout != unset_timeout) {
PyErr_SetString(PyExc_ValueError,
"timeout value must be positive");
"timeout value must be a non-negative number");
return -1;
}
if (!blocking)
Expand Down
43 changes: 26 additions & 17 deletions Modules/_xxinterpchannelsmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -242,9 +242,8 @@ add_new_type(PyObject *mod, PyType_Spec *spec, crossinterpdatafunc shared,
}

static int
wait_for_lock(PyThread_type_lock mutex)
wait_for_lock(PyThread_type_lock mutex, PY_TIMEOUT_T timeout)
{
PY_TIMEOUT_T timeout = PyThread_UNSET_TIMEOUT;
PyLockStatus res = PyThread_acquire_lock_timed_with_retries(mutex, timeout);
if (res == PY_LOCK_INTR) {
/* KeyboardInterrupt, etc. */
Expand Down Expand Up @@ -1883,7 +1882,8 @@ _channel_clear_sent(_channels *channels, int64_t cid, _waiting_t *waiting)
}

static int
_channel_send_wait(_channels *channels, int64_t cid, PyObject *obj)
_channel_send_wait(_channels *channels, int64_t cid, PyObject *obj,
PY_TIMEOUT_T timeout)
{
// We use a stack variable here, so we must ensure that &waiting
// is not held by any channel item at the point this function exits.
Expand All @@ -1901,7 +1901,7 @@ _channel_send_wait(_channels *channels, int64_t cid, PyObject *obj)
}

/* Wait until the object is received. */
if (wait_for_lock(waiting.mutex) < 0) {
if (wait_for_lock(waiting.mutex, timeout) < 0) {
assert(PyErr_Occurred());
_waiting_finish_releasing(&waiting);
/* The send() call is failing now, so make sure the item
Expand Down Expand Up @@ -2816,25 +2816,29 @@ receive end.");
static PyObject *
channel_send(PyObject *self, PyObject *args, PyObject *kwds)
{
// XXX Add a timeout arg.
static char *kwlist[] = {"cid", "obj", "blocking", NULL};
int64_t cid;
static char *kwlist[] = {"cid", "obj", "blocking", "timeout", NULL};
struct channel_id_converter_data cid_data = {
.module = self,
};
PyObject *obj;
int blocking = 1;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&O|$p:channel_send", kwlist,
PyObject *timeout_obj = NULL;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&O|$pO:channel_send", kwlist,
channel_id_converter, &cid_data, &obj,
&blocking)) {
&blocking, &timeout_obj)) {
return NULL;
}

int64_t cid = cid_data.cid;
PY_TIMEOUT_T timeout;
if (PyThread_ParseTimeoutArg(timeout_obj, blocking, &timeout) < 0) {
return NULL;
}
cid = cid_data.cid;

/* Queue up the object. */
int err = 0;
if (blocking) {
err = _channel_send_wait(&_globals.channels, cid, obj);
err = _channel_send_wait(&_globals.channels, cid, obj, timeout);
}
else {
err = _channel_send(&_globals.channels, cid, obj, NULL);
Expand All @@ -2855,20 +2859,25 @@ By default this waits for the object to be received.");
static PyObject *
channel_send_buffer(PyObject *self, PyObject *args, PyObject *kwds)
{
static char *kwlist[] = {"cid", "obj", "blocking", NULL};
int64_t cid;
static char *kwlist[] = {"cid", "obj", "blocking", "timeout", NULL};
struct channel_id_converter_data cid_data = {
.module = self,
};
PyObject *obj;
int blocking = 1;
PyObject *timeout_obj = NULL;
if (!PyArg_ParseTupleAndKeywords(args, kwds,
"O&O|$p:channel_send_buffer", kwlist,
"O&O|$pO:channel_send_buffer", kwlist,
channel_id_converter, &cid_data, &obj,
&blocking)) {
&blocking, &timeout_obj)) {
return NULL;
}

int64_t cid = cid_data.cid;
PY_TIMEOUT_T timeout;
if (PyThread_ParseTimeoutArg(timeout_obj, blocking, &timeout) < 0) {
return NULL;
}
cid = cid_data.cid;

PyObject *tempobj = PyMemoryView_FromObject(obj);
if (tempobj == NULL) {
Expand All @@ -2878,7 +2887,7 @@ channel_send_buffer(PyObject *self, PyObject *args, PyObject *kwds)
/* Queue up the object. */
int err = 0;
if (blocking) {
err = _channel_send_wait(&_globals.channels, cid, tempobj);
err = _channel_send_wait(&_globals.channels, cid, tempobj, timeout);
}
else {
err = _channel_send(&_globals.channels, cid, tempobj, NULL);
Expand Down
Loading