Skip to content

Commit 574daab

Browse files
author
Diptorup Deb
authored
Merge pull request #1454 from IntelPython/correct-refcount-handling-fixing-memory-leak
Fixed ref-counting of Python object temporaries in unboxing code
2 parents e36c979 + 5dcf8af commit 574daab

File tree

1 file changed

+17
-8
lines changed

1 file changed

+17
-8
lines changed

numba_dpex/core/runtime/_dpexrt_python.c

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -744,14 +744,18 @@ static struct PyUSMArrayObject *PyUSMNdArray_ARRAYOBJ(PyObject *obj)
744744
DPEXRT_DEBUG(
745745
drt_debug_print("DPEXRT-DEBUG: usm array was passed directly\n"));
746746
arrayobj = obj;
747+
Py_INCREF(arrayobj);
747748
}
748749
else if (PyObject_HasAttrString(obj, "_array_obj")) {
750+
// PyObject_GetAttrString gives reference
749751
arrayobj = PyObject_GetAttrString(obj, "_array_obj");
750752

751753
if (!arrayobj)
752754
return NULL;
753-
if (!PyObject_TypeCheck(arrayobj, &PyUSMArrayType))
755+
if (!PyObject_TypeCheck(arrayobj, &PyUSMArrayType)) {
756+
Py_DECREF(arrayobj);
754757
return NULL;
758+
}
755759
}
756760

757761
struct PyUSMArrayObject *pyusmarrayobj =
@@ -803,17 +807,13 @@ static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
803807
PyGILState_STATE gstate;
804808
npy_intp itemsize = 0;
805809

806-
// Increment the ref count on obj to prevent CPython from garbage
807-
// collecting the array.
808-
// TODO: add extra description why do we need this
809-
Py_IncRef(obj);
810-
811810
DPEXRT_DEBUG(drt_debug_print(
812811
"DPEXRT-DEBUG: In DPEXRT_sycl_usm_ndarray_from_python at %s, line %d\n",
813812
__FILE__, __LINE__));
814813

815814
// Check if the PyObject obj has an _array_obj attribute that is of
816815
// dpctl.tensor.usm_ndarray type.
816+
// arrayobj is a new reference, reference of obj is borrowed
817817
if (!(arrayobj = PyUSMNdArray_ARRAYOBJ(obj))) {
818818
DPEXRT_DEBUG(drt_debug_print(
819819
"DPEXRT-ERROR: PyUSMNdArray_ARRAYOBJ check failed at %s, line %d\n",
@@ -832,6 +832,7 @@ static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
832832
data = (void *)UsmNDArray_GetData(arrayobj);
833833
nitems = product_of_shape(shape, ndim);
834834
itemsize = (npy_intp)UsmNDArray_GetElementSize(arrayobj);
835+
835836
if (!(qref = UsmNDArray_GetQueueRef(arrayobj))) {
836837
DPEXRT_DEBUG(drt_debug_print(
837838
"DPEXRT-ERROR: UsmNDArray_GetQueueRef returned NULL at "
@@ -850,6 +851,9 @@ static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
850851
goto error;
851852
}
852853

854+
Py_XDECREF(arrayobj);
855+
Py_IncRef(obj);
856+
853857
arystruct->data = data;
854858
arystruct->sycl_queue = qref;
855859
arystruct->nitems = nitems;
@@ -906,7 +910,7 @@ static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
906910
__FILE__, __LINE__));
907911
gstate = PyGILState_Ensure();
908912
// decref the python object
909-
Py_DECREF(obj);
913+
Py_XDECREF((PyObject *)arrayobj);
910914
// release the GIL
911915
PyGILState_Release(gstate);
912916

@@ -938,26 +942,31 @@ static PyObject *box_from_arystruct_parent(usmarystruct_t *arystruct,
938942
drt_debug_print("DPEXRT-DEBUG: In box_from_arystruct_parent.\n"));
939943

940944
if (!(arrayobj = PyUSMNdArray_ARRAYOBJ(arystruct->parent))) {
945+
Py_XDECREF(arrayobj);
941946
DPEXRT_DEBUG(
942947
drt_debug_print("DPEXRT-DEBUG: Arrayobj cannot be boxed from "
943948
"parent as parent pointer is NULL.\n"));
944949
return NULL;
945950
}
946951

947952
if ((void *)UsmNDArray_GetData(arrayobj) != arystruct->data) {
953+
Py_XDECREF(arrayobj);
948954
DPEXRT_DEBUG(drt_debug_print(
949955
"DPEXRT-DEBUG: Arrayobj cannot be boxed "
950956
"from parent as data pointer in the arystruct is not the same as "
951957
"the data pointer in the parent object.\n"));
952958
return NULL;
953959
}
954960

955-
if (UsmNDArray_GetNDim(arrayobj) != ndim)
961+
if (UsmNDArray_GetNDim(arrayobj) != ndim) {
962+
Py_XDECREF(arrayobj);
956963
return NULL;
964+
}
957965

958966
p = arystruct->shape_and_strides;
959967
shape = UsmNDArray_GetShape(arrayobj);
960968
strides = UsmNDArray_GetStrides(arrayobj);
969+
Py_XDECREF(arrayobj);
961970

962971
// Ensure the shape of the array to be boxed matches the shape of the
963972
// original parent.

0 commit comments

Comments
 (0)