Skip to content

Commit 6d07877

Browse files
authored
Merge pull request #544 from vzhurba01/531-move-callback-wrappers
Move callback wrappers to Python layer
2 parents 73fa3c0 + f00d7bf commit 6d07877

File tree

4 files changed

+1719
-2408
lines changed

4 files changed

+1719
-2408
lines changed

cuda_bindings/cuda/bindings/cyruntime.pyx.in

Lines changed: 4 additions & 110 deletions
Original file line number | Diff line number | Diff line change
@@ -7,9 +7,6 @@
77
# is strictly prohibited.
88
#
99
# This code was automatically generated with version 12.8.0. Do not modify it directly.
10-
import cython
11-
from libcpp.map cimport map
12-
from libc.stdlib cimport malloc, free
1310
cimport cuda.bindings._bindings.cyruntime as cyruntime
1411
cimport cuda.bindings._lib.cyruntime.cyruntime as custom_cyruntime
1512

@@ -111,47 +108,14 @@ cdef cudaError_t cudaDeviceFlushGPUDirectRDMAWrites(cudaFlushGPUDirectRDMAWrites
111108

112109
{{if 'cudaDeviceRegisterAsyncNotification' in found_functions}}
113110

114-
ctypedef struct cudaAsyncCallbackData_st:
115-
cudaAsyncCallback callback
116-
void *userData
117-
118-
ctypedef cudaAsyncCallbackData_st cudaAsyncCallbackData
119-
120-
@cython.show_performance_hints(False)
121-
cdef void cudaAsyncNotificationCallbackWrapper(cudaAsyncNotificationInfo_t *info, void *data, cudaAsyncCallbackHandle_t handle) nogil:
122-
cdef cudaAsyncCallbackData *cbData = <cudaAsyncCallbackData *>data
123-
with gil:
124-
cbData.callback(info, cbData.userData, handle)
125-
126-
127111
cdef cudaError_t cudaDeviceRegisterAsyncNotification(int device, cudaAsyncCallback callbackFunc, void* userData, cudaAsyncCallbackHandle_t* callback) except ?cudaErrorCallRequiresNewerDriver nogil:
128-
cdef cudaAsyncCallbackData *cbData = NULL
129-
cdef cudaError_t err = cudaSuccess
130-
cbData = <cudaAsyncCallbackData *>malloc(sizeof(cbData[0]))
131-
132-
if cbData == NULL:
133-
return cudaErrorMemoryAllocation
134-
135-
cbData.callback = callbackFunc
136-
cbData.userData = userData
137-
err = cyruntime._cudaDeviceRegisterAsyncNotification(device, <cudaAsyncCallback>cudaAsyncNotificationCallbackWrapper, <void *>cbData, callback)
138-
if err != cudaSuccess:
139-
free(cbData)
140-
return err
141-
142-
m_global._asyncCallbackDataMap[callback[0]] = cbData
143-
return err
112+
return cyruntime._cudaDeviceRegisterAsyncNotification(device, callbackFunc, userData, callback)
144113
{{endif}}
145114

146115
{{if 'cudaDeviceUnregisterAsyncNotification' in found_functions}}
147116

148117
cdef cudaError_t cudaDeviceUnregisterAsyncNotification(int device, cudaAsyncCallbackHandle_t callback) except ?cudaErrorCallRequiresNewerDriver nogil:
149-
cdef cudaError_t err = cudaSuccess
150-
err = cyruntime._cudaDeviceUnregisterAsyncNotification(device, callback)
151-
if err == cudaSuccess:
152-
free(m_global._asyncCallbackDataMap[callback])
153-
m_global._asyncCallbackDataMap.erase(callback)
154-
return err
118+
return cyruntime._cudaDeviceUnregisterAsyncNotification(device, callback)
155119
{{endif}}
156120

157121
{{if 'cudaDeviceGetSharedMemConfig' in found_functions}}
@@ -354,35 +318,8 @@ cdef cudaError_t cudaStreamWaitEvent(cudaStream_t stream, cudaEvent_t event, uns
354318

355319
{{if 'cudaStreamAddCallback' in found_functions}}
356320

357-
ctypedef struct cudaStreamCallbackData_st:
358-
cudaStreamCallback_t callback
359-
void *userData
360-
361-
ctypedef cudaStreamCallbackData_st cudaStreamCallbackData
362-
363-
@cython.show_performance_hints(False)
364-
cdef void cudaStreamRtCallbackWrapper(cudaStream_t stream, cudaError_t status, void *data) nogil:
365-
cdef cudaStreamCallbackData *cbData = <cudaStreamCallbackData *>data
366-
with gil:
367-
cbData.callback(stream, status, cbData.userData)
368-
free(cbData)
369-
370-
371321
cdef cudaError_t cudaStreamAddCallback(cudaStream_t stream, cudaStreamCallback_t callback, void* userData, unsigned int flags) except ?cudaErrorCallRequiresNewerDriver nogil:
372-
cdef cudaStreamCallbackData *cbData = NULL
373-
cdef cudaError_t err = cudaSuccess
374-
cbData = <cudaStreamCallbackData *>malloc(sizeof(cbData[0]))
375-
376-
if cbData == NULL:
377-
return cudaErrorMemoryAllocation
378-
379-
cbData.callback = callback
380-
cbData.userData = userData
381-
err = cyruntime._cudaStreamAddCallback(stream, <cudaStreamCallback_t>cudaStreamRtCallbackWrapper, <void *>cbData, flags)
382-
if err != cudaSuccess:
383-
free(cbData)
384-
return err
385-
return err
322+
return cyruntime._cudaStreamAddCallback(stream, callback, userData, flags)
386323
{{endif}}
387324

388325
{{if 'cudaStreamSynchronize' in found_functions}}
@@ -579,35 +516,8 @@ cdef cudaError_t cudaFuncSetAttribute(const void* func, cudaFuncAttribute attr,
579516

580517
{{if 'cudaLaunchHostFunc' in found_functions}}
581518

582-
ctypedef struct cudaStreamHostCallbackData_st:
583-
cudaHostFn_t callback
584-
void *userData
585-
586-
ctypedef cudaStreamHostCallbackData_st cudaStreamHostCallbackData
587-
588-
@cython.show_performance_hints(False)
589-
cdef void cudaStreamRtHostCallbackWrapper(void *data) nogil:
590-
cdef cudaStreamHostCallbackData *cbData = <cudaStreamHostCallbackData *>data
591-
with gil:
592-
cbData.callback(cbData.userData)
593-
free(cbData)
594-
595-
596519
cdef cudaError_t cudaLaunchHostFunc(cudaStream_t stream, cudaHostFn_t fn, void* userData) except ?cudaErrorCallRequiresNewerDriver nogil:
597-
cdef cudaStreamHostCallbackData *cbData = NULL
598-
cdef cudaError_t err = cudaSuccess
599-
cbData = <cudaStreamHostCallbackData *>malloc(sizeof(cbData[0]))
600-
601-
if cbData == NULL:
602-
return cudaErrorMemoryAllocation
603-
604-
cbData.callback = fn
605-
cbData.userData = userData
606-
err = cyruntime._cudaLaunchHostFunc(stream, <cudaHostFn_t>cudaStreamRtHostCallbackWrapper, <void *>cbData)
607-
if err != cudaSuccess:
608-
free(cbData)
609-
return err
610-
return err
520+
return cyruntime._cudaLaunchHostFunc(stream, fn, userData)
611521
{{endif}}
612522

613523
{{if 'cudaFuncSetSharedMemConfig' in found_functions}}
@@ -1966,22 +1876,6 @@ cdef cudaError_t cudaGraphicsVDPAURegisterOutputSurface(cudaGraphicsResource** r
19661876
return custom_cyruntime._cudaGraphicsVDPAURegisterOutputSurface(resource, vdpSurface, flags)
19671877
{{endif}}
19681878

1969-
1970-
cdef class cudaBindingsRuntimeGlobal:
1971-
{{if 'cudaDeviceRegisterAsyncNotification' in found_functions}}
1972-
cdef map[cudaAsyncCallbackHandle_t, cudaAsyncCallbackData*] _asyncCallbackDataMap
1973-
{{endif}}
1974-
1975-
def __dealloc__(self):
1976-
pass
1977-
{{if 'cudaDeviceRegisterAsyncNotification' in found_functions}}
1978-
for item in self._asyncCallbackDataMap:
1979-
free(item.second)
1980-
self._asyncCallbackDataMap.clear()
1981-
{{endif}}
1982-
1983-
cdef cudaBindingsRuntimeGlobal m_global = cudaBindingsRuntimeGlobal()
1984-
19851879
{{if True}}
19861880

19871881
{{if 'Windows' != platform.system()}}

0 commit comments

Comments
 (0)