Skip to content

Commit f37a983

Browse files
[3.8] bpo-38005: Fixed comparing and creating of InterpreterID and ChannelID. (GH-15652) (GH-16145)
* Fix a crash in comparing with float (and maybe other crashes). * They are now never equal to strings and non-integer numbers. * Comparison with a large number no longer raises OverflowError. * Arbitrary exceptions no longer silenced in constructors and comparisons. * TypeError raised in the constructor contains now the name of the type. * Accept only ChannelID and int-like objects in channel functions. * Accept only InterpreterId, int-like objects and str in the InterpreterId constructor. * Accept int-like objects, not just int in interpreter related functions. (cherry picked from commit bf16991)
1 parent d322abb commit f37a983

File tree

5 files changed

+153
-177
lines changed

5 files changed

+153
-177
lines changed

Include/cpython/interpreteridobject.h

-2
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@ PyAPI_FUNC(PyObject *) _PyInterpreterID_New(int64_t);
1414
PyAPI_FUNC(PyObject *) _PyInterpreterState_GetIDObject(PyInterpreterState *);
1515
PyAPI_FUNC(PyInterpreterState *) _PyInterpreterID_LookUp(PyObject *);
1616

17-
PyAPI_FUNC(int64_t) _Py_CoerceID(PyObject *);
18-
1917
#ifdef __cplusplus
2018
}
2119
#endif

Lib/test/test__xxsubinterpreters.py

+38-39
Original file line numberDiff line numberDiff line change
@@ -526,30 +526,23 @@ def test_with_int(self):
526526
self.assertEqual(int(id), 10)
527527

528528
def test_coerce_id(self):
529-
id = interpreters.InterpreterID('10', force=True)
530-
self.assertEqual(int(id), 10)
531-
532-
id = interpreters.InterpreterID(10.0, force=True)
533-
self.assertEqual(int(id), 10)
534-
535529
class Int(str):
536-
def __init__(self, value):
537-
self._value = value
538-
def __int__(self):
539-
return self._value
530+
def __index__(self):
531+
return 10
540532

541-
id = interpreters.InterpreterID(Int(10), force=True)
542-
self.assertEqual(int(id), 10)
533+
for id in ('10', '1_0', Int()):
534+
with self.subTest(id=id):
535+
id = interpreters.InterpreterID(id, force=True)
536+
self.assertEqual(int(id), 10)
543537

544538
def test_bad_id(self):
545-
for id in [-1, 'spam']:
546-
with self.subTest(id):
547-
with self.assertRaises(ValueError):
548-
interpreters.InterpreterID(id)
549-
with self.assertRaises(OverflowError):
550-
interpreters.InterpreterID(2**64)
551-
with self.assertRaises(TypeError):
552-
interpreters.InterpreterID(object())
539+
self.assertRaises(TypeError, interpreters.InterpreterID, object())
540+
self.assertRaises(TypeError, interpreters.InterpreterID, 10.0)
541+
self.assertRaises(TypeError, interpreters.InterpreterID, b'10')
542+
self.assertRaises(ValueError, interpreters.InterpreterID, -1)
543+
self.assertRaises(ValueError, interpreters.InterpreterID, '-1')
544+
self.assertRaises(ValueError, interpreters.InterpreterID, 'spam')
545+
self.assertRaises(OverflowError, interpreters.InterpreterID, 2**64)
553546

554547
def test_does_not_exist(self):
555548
id = interpreters.channel_create()
@@ -572,6 +565,14 @@ def test_equality(self):
572565
self.assertTrue(id1 == id1)
573566
self.assertTrue(id1 == id2)
574567
self.assertTrue(id1 == int(id1))
568+
self.assertTrue(int(id1) == id1)
569+
self.assertTrue(id1 == float(int(id1)))
570+
self.assertTrue(float(int(id1)) == id1)
571+
self.assertFalse(id1 == float(int(id1)) + 0.1)
572+
self.assertFalse(id1 == str(int(id1)))
573+
self.assertFalse(id1 == 2**1000)
574+
self.assertFalse(id1 == float('inf'))
575+
self.assertFalse(id1 == 'spam')
575576
self.assertFalse(id1 == id3)
576577

577578
self.assertFalse(id1 != id1)
@@ -1105,30 +1106,20 @@ def test_with_kwargs(self):
11051106
self.assertEqual(cid.end, 'both')
11061107

11071108
def test_coerce_id(self):
1108-
cid = interpreters._channel_id('10', force=True)
1109-
self.assertEqual(int(cid), 10)
1110-
1111-
cid = interpreters._channel_id(10.0, force=True)
1112-
self.assertEqual(int(cid), 10)
1113-
11141109
class Int(str):
1115-
def __init__(self, value):
1116-
self._value = value
1117-
def __int__(self):
1118-
return self._value
1110+
def __index__(self):
1111+
return 10
11191112

1120-
cid = interpreters._channel_id(Int(10), force=True)
1113+
cid = interpreters._channel_id(Int(), force=True)
11211114
self.assertEqual(int(cid), 10)
11221115

11231116
def test_bad_id(self):
1124-
for cid in [-1, 'spam']:
1125-
with self.subTest(cid):
1126-
with self.assertRaises(ValueError):
1127-
interpreters._channel_id(cid)
1128-
with self.assertRaises(OverflowError):
1129-
interpreters._channel_id(2**64)
1130-
with self.assertRaises(TypeError):
1131-
interpreters._channel_id(object())
1117+
self.assertRaises(TypeError, interpreters._channel_id, object())
1118+
self.assertRaises(TypeError, interpreters._channel_id, 10.0)
1119+
self.assertRaises(TypeError, interpreters._channel_id, '10')
1120+
self.assertRaises(TypeError, interpreters._channel_id, b'10')
1121+
self.assertRaises(ValueError, interpreters._channel_id, -1)
1122+
self.assertRaises(OverflowError, interpreters._channel_id, 2**64)
11321123

11331124
def test_bad_kwargs(self):
11341125
with self.assertRaises(ValueError):
@@ -1164,6 +1155,14 @@ def test_equality(self):
11641155
self.assertTrue(cid1 == cid1)
11651156
self.assertTrue(cid1 == cid2)
11661157
self.assertTrue(cid1 == int(cid1))
1158+
self.assertTrue(int(cid1) == cid1)
1159+
self.assertTrue(cid1 == float(int(cid1)))
1160+
self.assertTrue(float(int(cid1)) == cid1)
1161+
self.assertFalse(cid1 == float(int(cid1)) + 0.1)
1162+
self.assertFalse(cid1 == str(int(cid1)))
1163+
self.assertFalse(cid1 == 2**1000)
1164+
self.assertFalse(cid1 == float('inf'))
1165+
self.assertFalse(cid1 == 'spam')
11671166
self.assertFalse(cid1 == cid3)
11681167

11691168
self.assertFalse(cid1 != cid1)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fixed comparing and creating of InterpreterID and ChannelID.

Modules/_xxsubinterpretersmodule.c

+65-85
Original file line numberDiff line numberDiff line change
@@ -1405,6 +1405,34 @@ typedef struct channelid {
14051405
_channels *channels;
14061406
} channelid;
14071407

1408+
static int
1409+
channel_id_converter(PyObject *arg, void *ptr)
1410+
{
1411+
int64_t cid;
1412+
if (PyObject_TypeCheck(arg, &ChannelIDtype)) {
1413+
cid = ((channelid *)arg)->id;
1414+
}
1415+
else if (PyIndex_Check(arg)) {
1416+
cid = PyLong_AsLongLong(arg);
1417+
if (cid == -1 && PyErr_Occurred()) {
1418+
return 0;
1419+
}
1420+
if (cid < 0) {
1421+
PyErr_Format(PyExc_ValueError,
1422+
"channel ID must be a non-negative int, got %R", arg);
1423+
return 0;
1424+
}
1425+
}
1426+
else {
1427+
PyErr_Format(PyExc_TypeError,
1428+
"channel ID must be an int, got %.100s",
1429+
arg->ob_type->tp_name);
1430+
return 0;
1431+
}
1432+
*(int64_t *)ptr = cid;
1433+
return 1;
1434+
}
1435+
14081436
static channelid *
14091437
newchannelid(PyTypeObject *cls, int64_t cid, int end, _channels *channels,
14101438
int force, int resolve)
@@ -1437,28 +1465,16 @@ static PyObject *
14371465
channelid_new(PyTypeObject *cls, PyObject *args, PyObject *kwds)
14381466
{
14391467
static char *kwlist[] = {"id", "send", "recv", "force", "_resolve", NULL};
1440-
PyObject *id;
1468+
int64_t cid;
14411469
int send = -1;
14421470
int recv = -1;
14431471
int force = 0;
14441472
int resolve = 0;
14451473
if (!PyArg_ParseTupleAndKeywords(args, kwds,
1446-
"O|$pppp:ChannelID.__new__", kwlist,
1447-
&id, &send, &recv, &force, &resolve))
1474+
"O&|$pppp:ChannelID.__new__", kwlist,
1475+
channel_id_converter, &cid, &send, &recv, &force, &resolve))
14481476
return NULL;
14491477

1450-
// Coerce and check the ID.
1451-
int64_t cid;
1452-
if (PyObject_TypeCheck(id, &ChannelIDtype)) {
1453-
cid = ((channelid *)id)->id;
1454-
}
1455-
else {
1456-
cid = _Py_CoerceID(id);
1457-
if (cid < 0) {
1458-
return NULL;
1459-
}
1460-
}
1461-
14621478
// Handle "send" and "recv".
14631479
if (send == 0 && recv == 0) {
14641480
PyErr_SetString(PyExc_ValueError,
@@ -1592,30 +1608,28 @@ channelid_richcompare(PyObject *self, PyObject *other, int op)
15921608
int equal;
15931609
if (PyObject_TypeCheck(other, &ChannelIDtype)) {
15941610
channelid *othercid = (channelid *)other;
1595-
if (cid->end != othercid->end) {
1596-
equal = 0;
1597-
}
1598-
else {
1599-
equal = (cid->id == othercid->id);
1600-
}
1611+
equal = (cid->end == othercid->end) && (cid->id == othercid->id);
16011612
}
1602-
else {
1603-
other = PyNumber_Long(other);
1604-
if (other == NULL) {
1605-
PyErr_Clear();
1606-
Py_RETURN_NOTIMPLEMENTED;
1607-
}
1608-
int64_t othercid = PyLong_AsLongLong(other);
1609-
Py_DECREF(other);
1610-
if (othercid == -1 && PyErr_Occurred() != NULL) {
1613+
else if (PyLong_Check(other)) {
1614+
/* Fast path */
1615+
int overflow;
1616+
long long othercid = PyLong_AsLongLongAndOverflow(other, &overflow);
1617+
if (othercid == -1 && PyErr_Occurred()) {
16111618
return NULL;
16121619
}
1613-
if (othercid < 0) {
1614-
equal = 0;
1615-
}
1616-
else {
1617-
equal = (cid->id == othercid);
1620+
equal = !overflow && (othercid >= 0) && (cid->id == othercid);
1621+
}
1622+
else if (PyNumber_Check(other)) {
1623+
PyObject *pyid = PyLong_FromLongLong(cid->id);
1624+
if (pyid == NULL) {
1625+
return NULL;
16181626
}
1627+
PyObject *res = PyObject_RichCompare(pyid, other, op);
1628+
Py_DECREF(pyid);
1629+
return res;
1630+
}
1631+
else {
1632+
Py_RETURN_NOTIMPLEMENTED;
16191633
}
16201634

16211635
if ((op == Py_EQ && equal) || (op == Py_NE && !equal)) {
@@ -1754,8 +1768,7 @@ static PyTypeObject ChannelIDtype = {
17541768
0, /* tp_getattro */
17551769
0, /* tp_setattro */
17561770
0, /* tp_as_buffer */
1757-
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE |
1758-
Py_TPFLAGS_LONG_SUBCLASS, /* tp_flags */
1771+
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
17591772
channelid_doc, /* tp_doc */
17601773
0, /* tp_traverse */
17611774
0, /* tp_clear */
@@ -2017,10 +2030,6 @@ interp_destroy(PyObject *self, PyObject *args, PyObject *kwds)
20172030
"O:destroy", kwlist, &id)) {
20182031
return NULL;
20192032
}
2020-
if (!PyLong_Check(id)) {
2021-
PyErr_SetString(PyExc_TypeError, "ID must be an int");
2022-
return NULL;
2023-
}
20242033

20252034
// Look up the interpreter.
20262035
PyInterpreterState *interp = _PyInterpreterID_LookUp(id);
@@ -2145,10 +2154,6 @@ interp_run_string(PyObject *self, PyObject *args, PyObject *kwds)
21452154
&id, &code, &shared)) {
21462155
return NULL;
21472156
}
2148-
if (!PyLong_Check(id)) {
2149-
PyErr_SetString(PyExc_TypeError, "first arg (ID) must be an int");
2150-
return NULL;
2151-
}
21522157

21532158
// Look up the interpreter.
21542159
PyInterpreterState *interp = _PyInterpreterID_LookUp(id);
@@ -2216,10 +2221,6 @@ interp_is_running(PyObject *self, PyObject *args, PyObject *kwds)
22162221
"O:is_running", kwlist, &id)) {
22172222
return NULL;
22182223
}
2219-
if (!PyLong_Check(id)) {
2220-
PyErr_SetString(PyExc_TypeError, "ID must be an int");
2221-
return NULL;
2222-
}
22232224

22242225
PyInterpreterState *interp = _PyInterpreterID_LookUp(id);
22252226
if (interp == NULL) {
@@ -2268,13 +2269,9 @@ static PyObject *
22682269
channel_destroy(PyObject *self, PyObject *args, PyObject *kwds)
22692270
{
22702271
static char *kwlist[] = {"cid", NULL};
2271-
PyObject *id;
2272-
if (!PyArg_ParseTupleAndKeywords(args, kwds,
2273-
"O:channel_destroy", kwlist, &id)) {
2274-
return NULL;
2275-
}
2276-
int64_t cid = _Py_CoerceID(id);
2277-
if (cid < 0) {
2272+
int64_t cid;
2273+
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&:channel_destroy", kwlist,
2274+
channel_id_converter, &cid)) {
22782275
return NULL;
22792276
}
22802277

@@ -2331,14 +2328,10 @@ static PyObject *
23312328
channel_send(PyObject *self, PyObject *args, PyObject *kwds)
23322329
{
23332330
static char *kwlist[] = {"cid", "obj", NULL};
2334-
PyObject *id;
2331+
int64_t cid;
23352332
PyObject *obj;
2336-
if (!PyArg_ParseTupleAndKeywords(args, kwds,
2337-
"OO:channel_send", kwlist, &id, &obj)) {
2338-
return NULL;
2339-
}
2340-
int64_t cid = _Py_CoerceID(id);
2341-
if (cid < 0) {
2333+
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&O:channel_send", kwlist,
2334+
channel_id_converter, &cid, &obj)) {
23422335
return NULL;
23432336
}
23442337

@@ -2357,13 +2350,9 @@ static PyObject *
23572350
channel_recv(PyObject *self, PyObject *args, PyObject *kwds)
23582351
{
23592352
static char *kwlist[] = {"cid", NULL};
2360-
PyObject *id;
2361-
if (!PyArg_ParseTupleAndKeywords(args, kwds,
2362-
"O:channel_recv", kwlist, &id)) {
2363-
return NULL;
2364-
}
2365-
int64_t cid = _Py_CoerceID(id);
2366-
if (cid < 0) {
2353+
int64_t cid;
2354+
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&:channel_recv", kwlist,
2355+
channel_id_converter, &cid)) {
23672356
return NULL;
23682357
}
23692358

@@ -2379,17 +2368,13 @@ static PyObject *
23792368
channel_close(PyObject *self, PyObject *args, PyObject *kwds)
23802369
{
23812370
static char *kwlist[] = {"cid", "send", "recv", "force", NULL};
2382-
PyObject *id;
2371+
int64_t cid;
23832372
int send = 0;
23842373
int recv = 0;
23852374
int force = 0;
23862375
if (!PyArg_ParseTupleAndKeywords(args, kwds,
2387-
"O|$ppp:channel_close", kwlist,
2388-
&id, &send, &recv, &force)) {
2389-
return NULL;
2390-
}
2391-
int64_t cid = _Py_CoerceID(id);
2392-
if (cid < 0) {
2376+
"O&|$ppp:channel_close", kwlist,
2377+
channel_id_converter, &cid, &send, &recv, &force)) {
23932378
return NULL;
23942379
}
23952380

@@ -2431,17 +2416,13 @@ channel_release(PyObject *self, PyObject *args, PyObject *kwds)
24312416
{
24322417
// Note that only the current interpreter is affected.
24332418
static char *kwlist[] = {"cid", "send", "recv", "force", NULL};
2434-
PyObject *id;
2419+
int64_t cid;
24352420
int send = 0;
24362421
int recv = 0;
24372422
int force = 0;
24382423
if (!PyArg_ParseTupleAndKeywords(args, kwds,
2439-
"O|$ppp:channel_release", kwlist,
2440-
&id, &send, &recv, &force)) {
2441-
return NULL;
2442-
}
2443-
int64_t cid = _Py_CoerceID(id);
2444-
if (cid < 0) {
2424+
"O&|$ppp:channel_release", kwlist,
2425+
channel_id_converter, &cid, &send, &recv, &force)) {
24452426
return NULL;
24462427
}
24472428
if (send == 0 && recv == 0) {
@@ -2538,7 +2519,6 @@ PyInit__xxsubinterpreters(void)
25382519
}
25392520

25402521
/* Initialize types */
2541-
ChannelIDtype.tp_base = &PyLong_Type;
25422522
if (PyType_Ready(&ChannelIDtype) != 0) {
25432523
return NULL;
25442524
}

0 commit comments

Comments
 (0)