Skip to content

Commit 3796a4c

Browse files
committed
bpo-16575: Add checks for unions passed by value to functions.
1 parent 8e7bb99 commit 3796a4c

File tree

5 files changed

+114
-0
lines changed

5 files changed

+114
-0
lines changed

Lib/ctypes/test/test_structures.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,53 @@ class Test3(Structure):
525525
self.assertEqual(s.data[0], 3.14159)
526526
self.assertEqual(s.data[1], 2.71828)
527527

528+
def test_union_by_value(self):
529+
# See bpo-16575
530+
531+
# These should mirror the structures in Modules/_ctypes/_ctypes_test.c
532+
533+
class Nested1(Structure):
534+
_fields = [
535+
('an_int', c_int),
536+
('another_int', c_int),
537+
]
538+
539+
class Test4(Union):
540+
_fields_ = [
541+
('a_long', c_long),
542+
('a_struct', Nested1),
543+
]
544+
545+
class Nested2(Structure):
546+
_fields = [
547+
('an_int', c_int),
548+
('a_union', Test4),
549+
]
550+
551+
class Test5(Structure):
552+
_fields = [
553+
('an_int', c_int),
554+
('nested', Nested2),
555+
]
556+
557+
test4 = Test4()
558+
dll = CDLL(_ctypes_test.__file__)
559+
with self.assertRaises(TypeError) as ctx:
560+
func = dll._testfunc_union_by_value1
561+
func.restype = c_long
562+
func.argtypes = (Test4,)
563+
result = func(test4)
564+
self.assertEqual(ctx.exception.args[0], 'item 1 in _argtypes_ passes '
565+
'a union by value, which is unsupported.')
566+
test5 = Test5()
567+
with self.assertRaises(TypeError) as ctx:
568+
func = dll._testfunc_union_by_value2
569+
func.restype = c_long
570+
func.argtypes = (Test5,)
571+
result = func(test5)
572+
self.assertEqual(ctx.exception.args[0], 'item 1 in _argtypes_ passes '
573+
'a union by value, which is unsupported.')
574+
528575
class PointerMemberTestCase(unittest.TestCase):
529576

530577
def test(self):

Modules/_ctypes/_ctypes.c

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2383,6 +2383,25 @@ converters_from_argtypes(PyObject *ob)
23832383
for (i = 0; i < nArgs; ++i) {
23842384
PyObject *cnv;
23852385
PyObject *tp = PyTuple_GET_ITEM(ob, i);
2386+
StgDictObject *stgdict = PyType_stgdict(tp);
2387+
2388+
if (stgdict != NULL) {
2389+
if (stgdict->flags & TYPEFLAG_HASUNION) {
2390+
Py_DECREF(converters);
2391+
Py_DECREF(ob);
2392+
if (!PyErr_Occurred()) {
2393+
PyErr_Format(PyExc_TypeError,
2394+
"item %zd in _argtypes_ passes a union by "
2395+
"value, which is unsupported.",
2396+
i + 1);
2397+
}
2398+
return NULL;
2399+
}
2400+
if (stgdict->flags & TYPEFLAG_HASBITFIELD) {
2401+
printf("found stgdict with bitfield\n");
2402+
}
2403+
}
2404+
23862405
if (_PyObject_LookupAttrId(tp, &PyId_from_param, &cnv) <= 0) {
23872406
Py_DECREF(converters);
23882407
Py_DECREF(ob);

Modules/_ctypes/_ctypes_test.c

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,44 @@ _testfunc_array_in_struct2(Test3 in)
114114
return result;
115115
}
116116

117+
typedef union {
118+
long a_long;
119+
struct {
120+
int an_int;
121+
int another_int;
122+
} a_struct;
123+
} Test4;
124+
125+
typedef struct {
126+
int an_int;
127+
struct {
128+
int an_int;
129+
Test4 a_union;
130+
} nested;
131+
} Test5;
132+
133+
EXPORT(long)
134+
_testfunc_union_by_value1(Test4 in) {
135+
long result = in.a_long + in.a_struct.an_int + in.a_struct.another_int;
136+
137+
/* As the union/struct are passed by value, changes to them shouldn't be
138+
* reflected in the caller.
139+
*/
140+
memset(&in, 0, sizeof(in));
141+
return result;
142+
}
143+
144+
EXPORT(long)
145+
_testfunc_union_by_value2(Test5 in) {
146+
long result = in.an_int + in.nested.an_int;
147+
148+
/* As the union/struct are passed by value, changes to them shouldn't be
149+
* reflected in the caller.
150+
*/
151+
memset(&in, 0, sizeof(in));
152+
return result;
153+
}
154+
117155
EXPORT(void)testfunc_array(int values[4])
118156
{
119157
printf("testfunc_array %d %d %d %d\n",

Modules/_ctypes/ctypes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,8 @@ PyObject *_ctypes_callproc(PPROC pProc,
288288

289289
#define TYPEFLAG_ISPOINTER 0x100
290290
#define TYPEFLAG_HASPOINTER 0x200
291+
#define TYPEFLAG_HASUNION 0x400
292+
#define TYPEFLAG_HASBITFIELD 0x800
291293

292294
#define DICTFLAG_FINAL 0x1000
293295

Modules/_ctypes/stgdict.c

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,13 @@ PyCStructUnionType_update_stgdict(PyObject *type, PyObject *fields, int isStruct
440440
PyMem_Free(stgdict->ffi_type_pointer.elements);
441441

442442
basedict = PyType_stgdict((PyObject *)((PyTypeObject *)type)->tp_base);
443+
if (basedict) {
444+
stgdict->flags |= (basedict->flags &
445+
(TYPEFLAG_HASUNION | TYPEFLAG_HASBITFIELD));
446+
}
447+
if (!isStruct) {
448+
stgdict->flags |= TYPEFLAG_HASUNION;
449+
}
443450
if (basedict && !use_broken_old_ctypes_semantics) {
444451
size = offset = basedict->size;
445452
align = basedict->align;
@@ -517,6 +524,7 @@ PyCStructUnionType_update_stgdict(PyObject *type, PyObject *fields, int isStruct
517524
stgdict->flags |= TYPEFLAG_HASPOINTER;
518525
dict->flags |= DICTFLAG_FINAL; /* mark field type final */
519526
if (PyTuple_Size(pair) == 3) { /* bits specified */
527+
stgdict->flags |= TYPEFLAG_HASBITFIELD;
520528
switch(dict->ffi_type_pointer.type) {
521529
case FFI_TYPE_UINT8:
522530
case FFI_TYPE_UINT16:

0 commit comments

Comments
 (0)