Skip to content

Move callback wrappers to Python layer #544

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 4 additions & 110 deletions cuda_bindings/cuda/bindings/cyruntime.pyx.in
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@
# is strictly prohibited.
#
# This code was automatically generated with version 12.8.0. Do not modify it directly.
import cython
from libcpp.map cimport map
from libc.stdlib cimport malloc, free
cimport cuda.bindings._bindings.cyruntime as cyruntime
cimport cuda.bindings._lib.cyruntime.cyruntime as custom_cyruntime

Expand Down Expand Up @@ -111,47 +108,14 @@ cdef cudaError_t cudaDeviceFlushGPUDirectRDMAWrites(cudaFlushGPUDirectRDMAWrites

{{if 'cudaDeviceRegisterAsyncNotification' in found_functions}}

ctypedef struct cudaAsyncCallbackData_st:
cudaAsyncCallback callback
void *userData

ctypedef cudaAsyncCallbackData_st cudaAsyncCallbackData

@cython.show_performance_hints(False)
cdef void cudaAsyncNotificationCallbackWrapper(cudaAsyncNotificationInfo_t *info, void *data, cudaAsyncCallbackHandle_t handle) nogil:
cdef cudaAsyncCallbackData *cbData = <cudaAsyncCallbackData *>data
with gil:
cbData.callback(info, cbData.userData, handle)


cdef cudaError_t cudaDeviceRegisterAsyncNotification(int device, cudaAsyncCallback callbackFunc, void* userData, cudaAsyncCallbackHandle_t* callback) except ?cudaErrorCallRequiresNewerDriver nogil:
cdef cudaAsyncCallbackData *cbData = NULL
cdef cudaError_t err = cudaSuccess
cbData = <cudaAsyncCallbackData *>malloc(sizeof(cbData[0]))

if cbData == NULL:
return cudaErrorMemoryAllocation

cbData.callback = callbackFunc
cbData.userData = userData
err = cyruntime._cudaDeviceRegisterAsyncNotification(device, <cudaAsyncCallback>cudaAsyncNotificationCallbackWrapper, <void *>cbData, callback)
if err != cudaSuccess:
free(cbData)
return err

m_global._asyncCallbackDataMap[callback[0]] = cbData
return err
return cyruntime._cudaDeviceRegisterAsyncNotification(device, callbackFunc, userData, callback)
{{endif}}

{{if 'cudaDeviceUnregisterAsyncNotification' in found_functions}}

cdef cudaError_t cudaDeviceUnregisterAsyncNotification(int device, cudaAsyncCallbackHandle_t callback) except ?cudaErrorCallRequiresNewerDriver nogil:
cdef cudaError_t err = cudaSuccess
err = cyruntime._cudaDeviceUnregisterAsyncNotification(device, callback)
if err == cudaSuccess:
free(m_global._asyncCallbackDataMap[callback])
m_global._asyncCallbackDataMap.erase(callback)
return err
return cyruntime._cudaDeviceUnregisterAsyncNotification(device, callback)
{{endif}}

{{if 'cudaDeviceGetSharedMemConfig' in found_functions}}
Expand Down Expand Up @@ -354,35 +318,8 @@ cdef cudaError_t cudaStreamWaitEvent(cudaStream_t stream, cudaEvent_t event, uns

{{if 'cudaStreamAddCallback' in found_functions}}

ctypedef struct cudaStreamCallbackData_st:
cudaStreamCallback_t callback
void *userData

ctypedef cudaStreamCallbackData_st cudaStreamCallbackData

@cython.show_performance_hints(False)
cdef void cudaStreamRtCallbackWrapper(cudaStream_t stream, cudaError_t status, void *data) nogil:
cdef cudaStreamCallbackData *cbData = <cudaStreamCallbackData *>data
with gil:
cbData.callback(stream, status, cbData.userData)
free(cbData)


cdef cudaError_t cudaStreamAddCallback(cudaStream_t stream, cudaStreamCallback_t callback, void* userData, unsigned int flags) except ?cudaErrorCallRequiresNewerDriver nogil:
cdef cudaStreamCallbackData *cbData = NULL
cdef cudaError_t err = cudaSuccess
cbData = <cudaStreamCallbackData *>malloc(sizeof(cbData[0]))

if cbData == NULL:
return cudaErrorMemoryAllocation

cbData.callback = callback
cbData.userData = userData
err = cyruntime._cudaStreamAddCallback(stream, <cudaStreamCallback_t>cudaStreamRtCallbackWrapper, <void *>cbData, flags)
if err != cudaSuccess:
free(cbData)
return err
return err
return cyruntime._cudaStreamAddCallback(stream, callback, userData, flags)
{{endif}}

{{if 'cudaStreamSynchronize' in found_functions}}
Expand Down Expand Up @@ -579,35 +516,8 @@ cdef cudaError_t cudaFuncSetAttribute(const void* func, cudaFuncAttribute attr,

{{if 'cudaLaunchHostFunc' in found_functions}}

ctypedef struct cudaStreamHostCallbackData_st:
cudaHostFn_t callback
void *userData

ctypedef cudaStreamHostCallbackData_st cudaStreamHostCallbackData

@cython.show_performance_hints(False)
cdef void cudaStreamRtHostCallbackWrapper(void *data) nogil:
cdef cudaStreamHostCallbackData *cbData = <cudaStreamHostCallbackData *>data
with gil:
cbData.callback(cbData.userData)
free(cbData)


cdef cudaError_t cudaLaunchHostFunc(cudaStream_t stream, cudaHostFn_t fn, void* userData) except ?cudaErrorCallRequiresNewerDriver nogil:
cdef cudaStreamHostCallbackData *cbData = NULL
cdef cudaError_t err = cudaSuccess
cbData = <cudaStreamHostCallbackData *>malloc(sizeof(cbData[0]))

if cbData == NULL:
return cudaErrorMemoryAllocation

cbData.callback = fn
cbData.userData = userData
err = cyruntime._cudaLaunchHostFunc(stream, <cudaHostFn_t>cudaStreamRtHostCallbackWrapper, <void *>cbData)
if err != cudaSuccess:
free(cbData)
return err
return err
return cyruntime._cudaLaunchHostFunc(stream, fn, userData)
{{endif}}

{{if 'cudaFuncSetSharedMemConfig' in found_functions}}
Expand Down Expand Up @@ -1966,22 +1876,6 @@ cdef cudaError_t cudaGraphicsVDPAURegisterOutputSurface(cudaGraphicsResource** r
return custom_cyruntime._cudaGraphicsVDPAURegisterOutputSurface(resource, vdpSurface, flags)
{{endif}}


cdef class cudaBindingsRuntimeGlobal:
{{if 'cudaDeviceRegisterAsyncNotification' in found_functions}}
cdef map[cudaAsyncCallbackHandle_t, cudaAsyncCallbackData*] _asyncCallbackDataMap
{{endif}}

def __dealloc__(self):
pass
{{if 'cudaDeviceRegisterAsyncNotification' in found_functions}}
for item in self._asyncCallbackDataMap:
free(item.second)
self._asyncCallbackDataMap.clear()
{{endif}}

cdef cudaBindingsRuntimeGlobal m_global = cudaBindingsRuntimeGlobal()

{{if True}}

{{if 'Windows' != platform.system()}}
Expand Down
Loading
Loading