Skip to content

gh-108240: Add _PyCapsule_SetTraverse() internal function #108339

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Include/pycapsule.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ PyAPI_FUNC(int) PyCapsule_SetName(PyObject *capsule, const char *name);

PyAPI_FUNC(int) PyCapsule_SetContext(PyObject *capsule, void *context);

#ifdef Py_BUILD_CORE
PyAPI_FUNC(int) _PyCapsule_SetTraverse(PyObject *op, traverseproc traverse_func, inquiry clear_func);
#endif

PyAPI_FUNC(void *) PyCapsule_Import(
const char *name, /* UTF-8 encoded string */
int no_block);
Expand Down
37 changes: 31 additions & 6 deletions Modules/socketmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -7314,20 +7314,39 @@ os_init(void)
}
#endif

static int
sock_capi_traverse(PyObject *capsule, visitproc visit, void *arg)
{
PySocketModule_APIObject *capi = PyCapsule_GetPointer(capsule, PySocket_CAPSULE_NAME);
assert(capi != NULL);
Py_VISIT(capi->Sock_Type);
return 0;
}

static int
sock_capi_clear(PyObject *capsule)
{
PySocketModule_APIObject *capi = PyCapsule_GetPointer(capsule, PySocket_CAPSULE_NAME);
assert(capi != NULL);
Py_CLEAR(capi->Sock_Type);
return 0;
}

static void
sock_free_api(PySocketModule_APIObject *capi)
sock_capi_free(PySocketModule_APIObject *capi)
{
Py_DECREF(capi->Sock_Type);
Py_XDECREF(capi->Sock_Type); // sock_capi_free() can clear it
Py_DECREF(capi->error);
Py_DECREF(capi->timeout_error);
PyMem_Free(capi);
}

static void
sock_destroy_api(PyObject *capsule)
sock_capi_destroy(PyObject *capsule)
{
void *capi = PyCapsule_GetPointer(capsule, PySocket_CAPSULE_NAME);
sock_free_api(capi);
assert(capi != NULL);
sock_capi_free(capi);
}

static PySocketModule_APIObject *
Expand Down Expand Up @@ -7432,11 +7451,17 @@ socket_exec(PyObject *m)
}
PyObject *capsule = PyCapsule_New(capi,
PySocket_CAPSULE_NAME,
sock_destroy_api);
sock_capi_destroy);
if (capsule == NULL) {
sock_free_api(capi);
sock_capi_free(capi);
goto error;
}
if (_PyCapsule_SetTraverse(capsule,
sock_capi_traverse, sock_capi_clear) < 0) {
sock_capi_free(capi);
goto error;
}

if (PyModule_Add(m, PySocket_CAPI_NAME, capsule) < 0) {
goto error;
}
Expand Down
166 changes: 100 additions & 66 deletions Objects/capsule.c
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,28 @@ typedef struct {
const char *name;
void *context;
PyCapsule_Destructor destructor;
traverseproc traverse_func;
inquiry clear_func;
} PyCapsule;



static int
_is_legal_capsule(PyCapsule *capsule, const char *invalid_capsule)
_is_legal_capsule(PyObject *op, const char *invalid_capsule)
{
if (!capsule || !PyCapsule_CheckExact(capsule) || capsule->pointer == NULL) {
PyErr_SetString(PyExc_ValueError, invalid_capsule);
return 0;
if (!op || !PyCapsule_CheckExact(op)) {
goto error;
}
PyCapsule *capsule = (PyCapsule *)op;

if (capsule->pointer == NULL) {
goto error;
}
return 1;

error:
PyErr_SetString(PyExc_ValueError, invalid_capsule);
return 0;
}

#define is_legal_capsule(capsule, name) \
Expand Down Expand Up @@ -50,7 +60,7 @@ PyCapsule_New(void *pointer, const char *name, PyCapsule_Destructor destructor)
return NULL;
}

capsule = PyObject_New(PyCapsule, &PyCapsule_Type);
capsule = PyObject_GC_New(PyCapsule, &PyCapsule_Type);
if (capsule == NULL) {
return NULL;
}
Expand All @@ -59,15 +69,18 @@ PyCapsule_New(void *pointer, const char *name, PyCapsule_Destructor destructor)
capsule->name = name;
capsule->context = NULL;
capsule->destructor = destructor;
capsule->traverse_func = NULL;
capsule->clear_func = NULL;
// Only track the capsule if _PyCapsule_SetTraverse() is called

return (PyObject *)capsule;
}


int
PyCapsule_IsValid(PyObject *o, const char *name)
PyCapsule_IsValid(PyObject *op, const char *name)
{
PyCapsule *capsule = (PyCapsule *)o;
PyCapsule *capsule = (PyCapsule *)op;

return (capsule != NULL &&
PyCapsule_CheckExact(capsule) &&
Expand All @@ -77,13 +90,12 @@ PyCapsule_IsValid(PyObject *o, const char *name)


void *
PyCapsule_GetPointer(PyObject *o, const char *name)
PyCapsule_GetPointer(PyObject *op, const char *name)
{
PyCapsule *capsule = (PyCapsule *)o;

if (!is_legal_capsule(capsule, "PyCapsule_GetPointer")) {
if (!is_legal_capsule(op, "PyCapsule_GetPointer")) {
return NULL;
}
PyCapsule *capsule = (PyCapsule *)op;

if (!name_matches(name, capsule->name)) {
PyErr_SetString(PyExc_ValueError, "PyCapsule_GetPointer called with incorrect name");
Expand All @@ -95,52 +107,48 @@ PyCapsule_GetPointer(PyObject *o, const char *name)


const char *
PyCapsule_GetName(PyObject *o)
PyCapsule_GetName(PyObject *op)
{
PyCapsule *capsule = (PyCapsule *)o;

if (!is_legal_capsule(capsule, "PyCapsule_GetName")) {
if (!is_legal_capsule(op, "PyCapsule_GetName")) {
return NULL;
}
PyCapsule *capsule = (PyCapsule *)op;
return capsule->name;
}


PyCapsule_Destructor
PyCapsule_GetDestructor(PyObject *o)
PyCapsule_GetDestructor(PyObject *op)
{
PyCapsule *capsule = (PyCapsule *)o;

if (!is_legal_capsule(capsule, "PyCapsule_GetDestructor")) {
if (!is_legal_capsule(op, "PyCapsule_GetDestructor")) {
return NULL;
}
PyCapsule *capsule = (PyCapsule *)op;
return capsule->destructor;
}


void *
PyCapsule_GetContext(PyObject *o)
PyCapsule_GetContext(PyObject *op)
{
PyCapsule *capsule = (PyCapsule *)o;

if (!is_legal_capsule(capsule, "PyCapsule_GetContext")) {
if (!is_legal_capsule(op, "PyCapsule_GetContext")) {
return NULL;
}
PyCapsule *capsule = (PyCapsule *)op;
return capsule->context;
}


int
PyCapsule_SetPointer(PyObject *o, void *pointer)
PyCapsule_SetPointer(PyObject *op, void *pointer)
{
PyCapsule *capsule = (PyCapsule *)o;

if (!pointer) {
PyErr_SetString(PyExc_ValueError, "PyCapsule_SetPointer called with null pointer");
if (!is_legal_capsule(op, "PyCapsule_SetPointer")) {
return -1;
}
PyCapsule *capsule = (PyCapsule *)op;

if (!is_legal_capsule(capsule, "PyCapsule_SetPointer")) {
if (!pointer) {
PyErr_SetString(PyExc_ValueError, "PyCapsule_SetPointer called with null pointer");
return -1;
}

Expand All @@ -150,47 +158,62 @@ PyCapsule_SetPointer(PyObject *o, void *pointer)


int
PyCapsule_SetName(PyObject *o, const char *name)
PyCapsule_SetName(PyObject *op, const char *name)
{
PyCapsule *capsule = (PyCapsule *)o;

if (!is_legal_capsule(capsule, "PyCapsule_SetName")) {
if (!is_legal_capsule(op, "PyCapsule_SetName")) {
return -1;
}
PyCapsule *capsule = (PyCapsule *)op;

capsule->name = name;
return 0;
}


int
PyCapsule_SetDestructor(PyObject *o, PyCapsule_Destructor destructor)
PyCapsule_SetDestructor(PyObject *op, PyCapsule_Destructor destructor)
{
PyCapsule *capsule = (PyCapsule *)o;

if (!is_legal_capsule(capsule, "PyCapsule_SetDestructor")) {
if (!is_legal_capsule(op, "PyCapsule_SetDestructor")) {
return -1;
}
PyCapsule *capsule = (PyCapsule *)op;

capsule->destructor = destructor;
return 0;
}


int
PyCapsule_SetContext(PyObject *o, void *context)
PyCapsule_SetContext(PyObject *op, void *context)
{
PyCapsule *capsule = (PyCapsule *)o;

if (!is_legal_capsule(capsule, "PyCapsule_SetContext")) {
if (!is_legal_capsule(op, "PyCapsule_SetContext")) {
return -1;
}
PyCapsule *capsule = (PyCapsule *)op;

capsule->context = context;
return 0;
}


int
_PyCapsule_SetTraverse(PyObject *op, traverseproc traverse_func, inquiry clear_func)
{
if (!is_legal_capsule(op, "_PyCapsule_SetTraverse")) {
return -1;
}
PyCapsule *capsule = (PyCapsule *)op;

if (!PyObject_GC_IsTracked(op)) {
PyObject_GC_Track(op);
Comment on lines +207 to +208
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could have used the internal APIs here:

Suggested change
if (!PyObject_GC_IsTracked(op)) {
PyObject_GC_Track(op);
assert(_PyObject_IS_GC(op));
if (!_PyObject_GC_IS_TRACKED(op)) {
_PyObject_GC_TRACK(op);
}

OTOH, _PyCapsule_SetTraverse is probably not part of hot code, so I guess the public APIs are fine.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used _PyObject_GC_IS_TRACKED() in my second PR, but without assert(_PyObject_IS_GC(op)) which looks overkill.

}

capsule->traverse_func = traverse_func;
capsule->clear_func = clear_func;
return 0;
}


void *
PyCapsule_Import(const char *name, int no_block)
{
Expand Down Expand Up @@ -249,13 +272,14 @@ PyCapsule_Import(const char *name, int no_block)


static void
capsule_dealloc(PyObject *o)
capsule_dealloc(PyObject *op)
{
PyCapsule *capsule = (PyCapsule *)o;
PyCapsule *capsule = (PyCapsule *)op;
PyObject_GC_UnTrack(op);
if (capsule->destructor) {
capsule->destructor(o);
capsule->destructor(op);
}
PyObject_Free(o);
PyObject_GC_Del(op);
}


Expand All @@ -279,6 +303,29 @@ capsule_repr(PyObject *o)
}


static int
capsule_traverse(PyCapsule *capsule, visitproc visit, void *arg)
{
if (capsule->traverse_func) {
return capsule->traverse_func((PyObject*)capsule, visit, arg);
}
else {
return 0;
}
Comment on lines +312 to +314
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

else is unneeded here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, the callback can be NULL. I added a check in my second PR to reject NULL callbacks.

}


static int
capsule_clear(PyCapsule *capsule)
{
if (capsule->clear_func) {
return capsule->clear_func((PyObject*)capsule);
}
else {
return 0;
}
}


PyDoc_STRVAR(PyCapsule_Type__doc__,
"Capsule objects let you wrap a C \"void *\" pointer in a Python\n\
Expand All @@ -293,27 +340,14 @@ Python import mechanism to link to one another.\n\

PyTypeObject PyCapsule_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
"PyCapsule", /*tp_name*/
sizeof(PyCapsule), /*tp_basicsize*/
0, /*tp_itemsize*/
/* methods */
capsule_dealloc, /*tp_dealloc*/
0, /*tp_vectorcall_offset*/
0, /*tp_getattr*/
0, /*tp_setattr*/
0, /*tp_as_async*/
capsule_repr, /*tp_repr*/
0, /*tp_as_number*/
0, /*tp_as_sequence*/
0, /*tp_as_mapping*/
0, /*tp_hash*/
0, /*tp_call*/
0, /*tp_str*/
0, /*tp_getattro*/
0, /*tp_setattro*/
0, /*tp_as_buffer*/
0, /*tp_flags*/
PyCapsule_Type__doc__ /*tp_doc*/
.tp_name = "PyCapsule",
.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
.tp_basicsize = sizeof(PyCapsule),
.tp_dealloc = capsule_dealloc,
.tp_repr = capsule_repr,
.tp_doc = PyCapsule_Type__doc__,
.tp_traverse = (traverseproc)capsule_traverse,
.tp_clear = (inquiry)capsule_clear,
};