From 88e5c839f485668dbdd007a3b0cea0a3b57c2d34 Mon Sep 17 00:00:00 2001 From: Jorge Ortega Date: Wed, 2 Apr 2025 16:02:58 -0700 Subject: [PATCH] refactor(cust_raw): Parse macro function renames in cuda headers. Adds a custom bindgen callback that prevents function renames due to macro defines while still linking to the intend function after macro expansion. It does this by tracking macros when generating the bindings, so that it can change the name of the function back to what it was before the macro changed and link to the macro expanded function name. Doing so helps prevents breaking changes across CUDA versions when generating the bindings, and generates function bindings that match those used in Nvidia's CUDA documentation. misc.: optix-sys rebuilds if related optix environment variables change. misc.: unpin cc and jobserver from add example. --- crates/blastoff/src/context.rs | 14 +-- crates/blastoff/src/raw/level1.rs | 92 +++++++------- crates/blastoff/src/raw/level3.rs | 8 +- crates/cust/src/context/legacy.rs | 10 +- crates/cust/src/context/mod.rs | 8 +- crates/cust/src/device.rs | 2 +- crates/cust/src/event.rs | 6 +- crates/cust/src/graph.rs | 4 +- crates/cust/src/link.rs | 8 +- crates/cust/src/memory/array.rs | 20 ++- crates/cust/src/memory/device/device_box.rs | 26 ++-- .../cust/src/memory/device/device_buffer.rs | 8 +- crates/cust/src/memory/device/device_slice.rs | 30 ++--- crates/cust/src/memory/malloc.rs | 10 +- crates/cust/src/memory/mod.rs | 14 +-- crates/cust/src/module.rs | 14 +-- crates/cust/src/stream.rs | 4 +- crates/cust_raw/Cargo.toml | 2 + crates/cust_raw/build/callbacks.rs | 117 ++++++++++++++++++ crates/cust_raw/build/main.rs | 32 ++++- crates/optix-sys/build/main.rs | 3 + crates/optix-sys/build/optix_sdk.rs | 13 +- crates/optix/examples/ex02_pipeline/build.rs | 2 +- crates/optix/examples/ex03_window/build.rs | 2 +- examples/cuda/cpu/add/Cargo.toml | 2 - 25 files changed, 289 insertions(+), 162 deletions(-) create mode 100644 crates/cust_raw/build/callbacks.rs diff --git a/crates/blastoff/src/context.rs b/crates/blastoff/src/context.rs index 421e2d49..2a517be0 100644 --- a/crates/blastoff/src/context.rs +++ b/crates/blastoff/src/context.rs @@ -92,8 +92,8 @@ impl CublasContext { pub fn new() -> Result { let mut raw = MaybeUninit::uninit(); unsafe { - cublas_sys::cublasCreate_v2(raw.as_mut_ptr()).to_result()?; - cublas_sys::cublasSetPointerMode_v2( + cublas_sys::cublasCreate(raw.as_mut_ptr()).to_result()?; + cublas_sys::cublasSetPointerMode( raw.assume_init(), cublas_sys::cublasPointerMode_t::CUBLAS_POINTER_MODE_DEVICE, ) @@ -112,7 +112,7 @@ impl CublasContext { unsafe { let inner = mem::replace(&mut ctx.raw, ptr::null_mut()); - match cublas_sys::cublasDestroy_v2(inner).to_result() { + match cublas_sys::cublasDestroy(inner).to_result() { Ok(()) => { mem::forget(ctx); Ok(()) @@ -127,7 +127,7 @@ impl CublasContext { let mut raw = MaybeUninit::::uninit(); unsafe { // getVersion can't fail - cublas_sys::cublasGetVersion_v2(self.raw, raw.as_mut_ptr().cast()) + cublas_sys::cublasGetVersion(self.raw, raw.as_mut_ptr().cast()) .to_result() .unwrap(); @@ -145,7 +145,7 @@ impl CublasContext { ) -> Result { unsafe { // cudaStream_t is the same as CUstream - cublas_sys::cublasSetStream_v2( + cublas_sys::cublasSetStream( self.raw, mem::transmute::<*mut driver_sys::CUstream_st, *mut cublas_sys::CUstream_st>( stream.as_inner(), @@ -155,7 +155,7 @@ impl CublasContext { let res = func(self)?; // reset the stream back to NULL just in case someone calls with_stream, then drops the stream, and tries to // execute a raw sys function with the context's handle. - cublas_sys::cublasSetStream_v2(self.raw, ptr::null_mut()).to_result()?; + cublas_sys::cublasSetStream(self.raw, ptr::null_mut()).to_result()?; Ok(res) } } @@ -340,7 +340,7 @@ impl CublasContext { impl Drop for CublasContext { fn drop(&mut self) { unsafe { - cublas_sys::cublasDestroy_v2(self.raw); + cublas_sys::cublasDestroy(self.raw); } } } diff --git a/crates/blastoff/src/raw/level1.rs b/crates/blastoff/src/raw/level1.rs index d2167fbb..7268d642 100644 --- a/crates/blastoff/src/raw/level1.rs +++ b/crates/blastoff/src/raw/level1.rs @@ -103,7 +103,7 @@ impl Level1 for f32 { incx: c_int, result: *mut c_int, ) -> cublasStatus_t { - cublasIsamax_v2(handle, n, x, incx, result) + cublasIsamax(handle, n, x, incx, result) } unsafe fn amin( handle: cublasHandle_t, @@ -112,7 +112,7 @@ impl Level1 for f32 { incx: c_int, result: *mut c_int, ) -> cublasStatus_t { - cublasIsamin_v2(handle, n, x, incx, result) + cublasIsamin(handle, n, x, incx, result) } unsafe fn axpy( handle: cublasHandle_t, @@ -123,7 +123,7 @@ impl Level1 for f32 { y: *mut Self, incy: c_int, ) -> cublasStatus_t { - cublasSaxpy_v2(handle, n, alpha, x, incx, y, incy) + cublasSaxpy(handle, n, alpha, x, incx, y, incy) } unsafe fn copy( handle: cublasHandle_t, @@ -133,7 +133,7 @@ impl Level1 for f32 { y: *mut Self, incy: c_int, ) -> cublasStatus_t { - cublasScopy_v2(handle, n, x, incx, y, incy) + cublasScopy(handle, n, x, incx, y, incy) } unsafe fn nrm2( handle: cublasHandle_t, @@ -142,7 +142,7 @@ impl Level1 for f32 { incx: c_int, result: *mut Self::FloatTy, ) -> cublasStatus_t { - cublasSnrm2_v2(handle, n, x, incx, result) + cublasSnrm2(handle, n, x, incx, result) } unsafe fn rot( handle: cublasHandle_t, @@ -154,7 +154,7 @@ impl Level1 for f32 { c: *const Self::FloatTy, s: *const Self, ) -> cublasStatus_t { - cublasSrot_v2(handle, n, x, incx, y, incy, c, s) + cublasSrot(handle, n, x, incx, y, incy, c, s) } unsafe fn rotg( handle: cublasHandle_t, @@ -163,7 +163,7 @@ impl Level1 for f32 { c: *mut Self::FloatTy, s: *mut Self, ) -> cublasStatus_t { - cublasSrotg_v2(handle, a, b, c, s) + cublasSrotg(handle, a, b, c, s) } unsafe fn rotm( handle: cublasHandle_t, @@ -174,7 +174,7 @@ impl Level1 for f32 { incy: c_int, param: *const Self::FloatTy, ) -> cublasStatus_t { - cublasSrotm_v2(handle, n, x, incx, y, incy, param) + cublasSrotm(handle, n, x, incx, y, incy, param) } unsafe fn rotmg( handle: cublasHandle_t, @@ -184,7 +184,7 @@ impl Level1 for f32 { y1: *const Self, param: *mut Self, ) -> cublasStatus_t { - cublasSrotmg_v2(handle, d1, d2, x1, y1, param) + cublasSrotmg(handle, d1, d2, x1, y1, param) } unsafe fn scal( handle: cublasHandle_t, @@ -193,7 +193,7 @@ impl Level1 for f32 { x: *mut Self, incx: c_int, ) -> cublasStatus_t { - cublasSscal_v2(handle, n, alpha, x, incx) + cublasSscal(handle, n, alpha, x, incx) } unsafe fn swap( handle: cublasHandle_t, @@ -203,7 +203,7 @@ impl Level1 for f32 { y: *mut Self, incy: c_int, ) -> cublasStatus_t { - cublasSswap_v2(handle, n, x, incx, y, incy) + cublasSswap(handle, n, x, incx, y, incy) } } @@ -215,7 +215,7 @@ impl Level1 for f64 { incx: c_int, result: *mut c_int, ) -> cublasStatus_t { - cublasIdamax_v2(handle, n, x, incx, result) + cublasIdamax(handle, n, x, incx, result) } unsafe fn amin( handle: cublasHandle_t, @@ -224,7 +224,7 @@ impl Level1 for f64 { incx: c_int, result: *mut c_int, ) -> cublasStatus_t { - cublasIdamin_v2(handle, n, x, incx, result) + cublasIdamin(handle, n, x, incx, result) } unsafe fn axpy( handle: cublasHandle_t, @@ -235,7 +235,7 @@ impl Level1 for f64 { y: *mut Self, incy: c_int, ) -> cublasStatus_t { - cublasDaxpy_v2(handle, n, alpha, x, incx, y, incy) + cublasDaxpy(handle, n, alpha, x, incx, y, incy) } unsafe fn copy( handle: cublasHandle_t, @@ -245,7 +245,7 @@ impl Level1 for f64 { y: *mut Self, incy: c_int, ) -> cublasStatus_t { - cublasDcopy_v2(handle, n, x, incx, y, incy) + cublasDcopy(handle, n, x, incx, y, incy) } unsafe fn nrm2( handle: cublasHandle_t, @@ -254,7 +254,7 @@ impl Level1 for f64 { incx: c_int, result: *mut Self::FloatTy, ) -> cublasStatus_t { - cublasDnrm2_v2(handle, n, x, incx, result) + cublasDnrm2(handle, n, x, incx, result) } unsafe fn rot( handle: cublasHandle_t, @@ -266,7 +266,7 @@ impl Level1 for f64 { c: *const Self::FloatTy, s: *const Self, ) -> cublasStatus_t { - cublasDrot_v2(handle, n, x, incx, y, incy, c, s) + cublasDrot(handle, n, x, incx, y, incy, c, s) } unsafe fn rotg( handle: cublasHandle_t, @@ -275,7 +275,7 @@ impl Level1 for f64 { c: *mut Self::FloatTy, s: *mut Self, ) -> cublasStatus_t { - cublasDrotg_v2(handle, a, b, c, s) + cublasDrotg(handle, a, b, c, s) } unsafe fn rotm( handle: cublasHandle_t, @@ -286,7 +286,7 @@ impl Level1 for f64 { incy: c_int, param: *const Self::FloatTy, ) -> cublasStatus_t { - cublasDrotm_v2(handle, n, x, incx, y, incy, param) + cublasDrotm(handle, n, x, incx, y, incy, param) } unsafe fn rotmg( handle: cublasHandle_t, @@ -296,7 +296,7 @@ impl Level1 for f64 { y1: *const Self, param: *mut Self, ) -> cublasStatus_t { - cublasDrotmg_v2(handle, d1, d2, x1, y1, param) + cublasDrotmg(handle, d1, d2, x1, y1, param) } unsafe fn scal( handle: cublasHandle_t, @@ -305,7 +305,7 @@ impl Level1 for f64 { x: *mut Self, incx: c_int, ) -> cublasStatus_t { - cublasDscal_v2(handle, n, alpha, x, incx) + cublasDscal(handle, n, alpha, x, incx) } unsafe fn swap( handle: cublasHandle_t, @@ -315,7 +315,7 @@ impl Level1 for f64 { y: *mut Self, incy: c_int, ) -> cublasStatus_t { - cublasDswap_v2(handle, n, x, incx, y, incy) + cublasDswap(handle, n, x, incx, y, incy) } } @@ -327,7 +327,7 @@ impl Level1 for Complex32 { incx: c_int, result: *mut c_int, ) -> cublasStatus_t { - cublasIcamax_v2(handle, n, x.cast(), incx, result) + cublasIcamax(handle, n, x.cast(), incx, result) } unsafe fn amin( handle: cublasHandle_t, @@ -336,7 +336,7 @@ impl Level1 for Complex32 { incx: c_int, result: *mut c_int, ) -> cublasStatus_t { - cublasIcamin_v2(handle, n, x.cast(), incx, result) + cublasIcamin(handle, n, x.cast(), incx, result) } unsafe fn axpy( handle: cublasHandle_t, @@ -347,7 +347,7 @@ impl Level1 for Complex32 { y: *mut Self, incy: c_int, ) -> cublasStatus_t { - cublasCaxpy_v2(handle, n, alpha.cast(), x.cast(), incx, y.cast(), incy) + cublasCaxpy(handle, n, alpha.cast(), x.cast(), incx, y.cast(), incy) } unsafe fn copy( handle: cublasHandle_t, @@ -357,7 +357,7 @@ impl Level1 for Complex32 { y: *mut Self, incy: c_int, ) -> cublasStatus_t { - cublasCcopy_v2(handle, n, x.cast(), incx, y.cast(), incy) + cublasCcopy(handle, n, x.cast(), incx, y.cast(), incy) } unsafe fn nrm2( handle: cublasHandle_t, @@ -366,7 +366,7 @@ impl Level1 for Complex32 { incx: c_int, result: *mut Self::FloatTy, ) -> cublasStatus_t { - cublasScnrm2_v2(handle, n, x.cast(), incx, result) + cublasScnrm2(handle, n, x.cast(), incx, result) } unsafe fn rot( handle: cublasHandle_t, @@ -378,7 +378,7 @@ impl Level1 for Complex32 { c: *const Self::FloatTy, s: *const Self::FloatTy, ) -> cublasStatus_t { - cublasCsrot_v2(handle, n, x.cast(), incx, y.cast(), incy, c, s) + cublasCsrot(handle, n, x.cast(), incx, y.cast(), incy, c, s) } unsafe fn rotg( handle: cublasHandle_t, @@ -387,7 +387,7 @@ impl Level1 for Complex32 { c: *mut Self::FloatTy, s: *mut Self, ) -> cublasStatus_t { - cublasCrotg_v2(handle, a.cast(), b.cast(), c, s.cast()) + cublasCrotg(handle, a.cast(), b.cast(), c, s.cast()) } unsafe fn rotm( _handle: cublasHandle_t, @@ -417,7 +417,7 @@ impl Level1 for Complex32 { x: *mut Self, incx: c_int, ) -> cublasStatus_t { - cublasCscal_v2(handle, n, alpha.cast(), x.cast(), incx) + cublasCscal(handle, n, alpha.cast(), x.cast(), incx) } unsafe fn swap( handle: cublasHandle_t, @@ -427,7 +427,7 @@ impl Level1 for Complex32 { y: *mut Self, incy: c_int, ) -> cublasStatus_t { - cublasCswap_v2(handle, n, x.cast(), incx, y.cast(), incy) + cublasCswap(handle, n, x.cast(), incx, y.cast(), incy) } } @@ -439,7 +439,7 @@ impl Level1 for Complex64 { incx: c_int, result: *mut c_int, ) -> cublasStatus_t { - cublasIzamax_v2(handle, n, x.cast(), incx, result) + cublasIzamax(handle, n, x.cast(), incx, result) } unsafe fn amin( handle: cublasHandle_t, @@ -448,7 +448,7 @@ impl Level1 for Complex64 { incx: c_int, result: *mut c_int, ) -> cublasStatus_t { - cublasIzamin_v2(handle, n, x.cast(), incx, result) + cublasIzamin(handle, n, x.cast(), incx, result) } unsafe fn axpy( handle: cublasHandle_t, @@ -459,7 +459,7 @@ impl Level1 for Complex64 { y: *mut Self, incy: c_int, ) -> cublasStatus_t { - cublasZaxpy_v2(handle, n, alpha.cast(), x.cast(), incx, y.cast(), incy) + cublasZaxpy(handle, n, alpha.cast(), x.cast(), incx, y.cast(), incy) } unsafe fn copy( handle: cublasHandle_t, @@ -469,7 +469,7 @@ impl Level1 for Complex64 { y: *mut Self, incy: c_int, ) -> cublasStatus_t { - cublasZcopy_v2(handle, n, x.cast(), incx, y.cast(), incy) + cublasZcopy(handle, n, x.cast(), incx, y.cast(), incy) } unsafe fn nrm2( handle: cublasHandle_t, @@ -478,7 +478,7 @@ impl Level1 for Complex64 { incx: c_int, result: *mut Self::FloatTy, ) -> cublasStatus_t { - cublasDznrm2_v2(handle, n, x.cast(), incx, result) + cublasDznrm2(handle, n, x.cast(), incx, result) } unsafe fn rot( handle: cublasHandle_t, @@ -490,7 +490,7 @@ impl Level1 for Complex64 { c: *const Self::FloatTy, s: *const Self::FloatTy, ) -> cublasStatus_t { - cublasZdrot_v2(handle, n, x.cast(), incx, y.cast(), incy, c, s) + cublasZdrot(handle, n, x.cast(), incx, y.cast(), incy, c, s) } unsafe fn rotg( handle: cublasHandle_t, @@ -499,7 +499,7 @@ impl Level1 for Complex64 { c: *mut Self::FloatTy, s: *mut Self, ) -> cublasStatus_t { - cublasZrotg_v2(handle, a.cast(), b.cast(), c, s.cast()) + cublasZrotg(handle, a.cast(), b.cast(), c, s.cast()) } unsafe fn rotm( _handle: cublasHandle_t, @@ -529,7 +529,7 @@ impl Level1 for Complex64 { x: *mut Self, incx: c_int, ) -> cublasStatus_t { - cublasZscal_v2(handle, n, alpha.cast(), x.cast(), incx) + cublasZscal(handle, n, alpha.cast(), x.cast(), incx) } unsafe fn swap( handle: cublasHandle_t, @@ -539,7 +539,7 @@ impl Level1 for Complex64 { y: *mut Self, incy: c_int, ) -> cublasStatus_t { - cublasZswap_v2(handle, n, x.cast(), incx, y.cast(), incy) + cublasZswap(handle, n, x.cast(), incx, y.cast(), incy) } } @@ -575,7 +575,7 @@ impl ComplexLevel1 for Complex32 { incy: c_int, result: *mut Self, ) -> cublasStatus_t { - cublasCdotu_v2(handle, n, x.cast(), incx, y.cast(), incy, result.cast()) + cublasCdotu(handle, n, x.cast(), incx, y.cast(), incy, result.cast()) } unsafe fn dotc( handle: cublasHandle_t, @@ -586,7 +586,7 @@ impl ComplexLevel1 for Complex32 { incy: c_int, result: *mut Self, ) -> cublasStatus_t { - cublasCdotc_v2(handle, n, x.cast(), incx, y.cast(), incy, result.cast()) + cublasCdotc(handle, n, x.cast(), incx, y.cast(), incy, result.cast()) } } @@ -600,7 +600,7 @@ impl ComplexLevel1 for Complex64 { incy: c_int, result: *mut Self, ) -> cublasStatus_t { - cublasZdotu_v2(handle, n, x.cast(), incx, y.cast(), incy, result.cast()) + cublasZdotu(handle, n, x.cast(), incx, y.cast(), incy, result.cast()) } unsafe fn dotc( handle: cublasHandle_t, @@ -611,7 +611,7 @@ impl ComplexLevel1 for Complex64 { incy: c_int, result: *mut Self, ) -> cublasStatus_t { - cublasZdotc_v2(handle, n, x.cast(), incx, y.cast(), incy, result.cast()) + cublasZdotc(handle, n, x.cast(), incx, y.cast(), incy, result.cast()) } } @@ -638,7 +638,7 @@ impl FloatLevel1 for f32 { incy: c_int, result: *mut Self, ) -> cublasStatus_t { - cublasSdot_v2(handle, n, x, incx, y, incy, result) + cublasSdot(handle, n, x, incx, y, incy, result) } } @@ -652,6 +652,6 @@ impl FloatLevel1 for f64 { incy: c_int, result: *mut Self, ) -> cublasStatus_t { - cublasDdot_v2(handle, n, x, incx, y, incy, result) + cublasDdot(handle, n, x, incx, y, incy, result) } } diff --git a/crates/blastoff/src/raw/level3.rs b/crates/blastoff/src/raw/level3.rs index 5e6d8e17..3e770a29 100644 --- a/crates/blastoff/src/raw/level3.rs +++ b/crates/blastoff/src/raw/level3.rs @@ -85,7 +85,7 @@ impl GemmOps for f32 { c: *mut Self, ldc: c_int, ) -> cublasStatus_t { - cublasSgemm_v2( + cublasSgemm( handle, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, ) } @@ -108,7 +108,7 @@ impl GemmOps for f64 { c: *mut Self, ldc: c_int, ) -> cublasStatus_t { - cublasDgemm_v2( + cublasDgemm( handle, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, ) } @@ -131,7 +131,7 @@ impl GemmOps for Complex32 { c: *mut Self, ldc: c_int, ) -> cublasStatus_t { - cublasCgemm_v2( + cublasCgemm( handle, transa, transb, @@ -167,7 +167,7 @@ impl GemmOps for Complex64 { c: *mut Self, ldc: c_int, ) -> cublasStatus_t { - cublasCgemm_v2( + cublasCgemm( handle, transa, transb, diff --git a/crates/cust/src/context/legacy.rs b/crates/cust/src/context/legacy.rs index 838d6d77..53f2b501 100644 --- a/crates/cust/src/context/legacy.rs +++ b/crates/cust/src/context/legacy.rs @@ -262,7 +262,7 @@ impl Context { // lifetime guarantees so we create-and-push, then pop, then the programmer has to // push again. let mut ctx: CUcontext = ptr::null_mut(); - driver_sys::cuCtxCreate_v2(&mut ctx as *mut CUcontext, flags.bits(), device.as_raw()) + driver_sys::cuCtxCreate(&mut ctx as *mut CUcontext, flags.bits(), device.as_raw()) .to_result()?; Ok(Context { inner: ctx }) } @@ -354,7 +354,7 @@ impl Context { unsafe { let inner = mem::replace(&mut ctx.inner, ptr::null_mut()); - match driver_sys::cuCtxDestroy_v2(inner).to_result() { + match driver_sys::cuCtxDestroy(inner).to_result() { Ok(()) => { mem::forget(ctx); Ok(()) @@ -372,7 +372,7 @@ impl Drop for Context { unsafe { let inner = mem::replace(&mut self.inner, ptr::null_mut()); - driver_sys::cuCtxDestroy_v2(inner); + driver_sys::cuCtxDestroy(inner); } } } @@ -456,7 +456,7 @@ impl ContextStack { pub fn pop() -> CudaResult { unsafe { let mut ctx: CUcontext = ptr::null_mut(); - driver_sys::cuCtxPopCurrent_v2(&mut ctx as *mut CUcontext).to_result()?; + driver_sys::cuCtxPopCurrent(&mut ctx as *mut CUcontext).to_result()?; Ok(UnownedContext { inner: ctx }) } } @@ -481,7 +481,7 @@ impl ContextStack { /// ``` pub fn push(ctx: &C) -> CudaResult<()> { unsafe { - driver_sys::cuCtxPushCurrent_v2(ctx.get_inner()).to_result()?; + driver_sys::cuCtxPushCurrent(ctx.get_inner()).to_result()?; Ok(()) } } diff --git a/crates/cust/src/context/mod.rs b/crates/cust/src/context/mod.rs index 6b2551bd..eb67e28a 100644 --- a/crates/cust/src/context/mod.rs +++ b/crates/cust/src/context/mod.rs @@ -215,13 +215,13 @@ impl Context { /// Nothing else should be using the primary context for this device, otherwise, /// spurious errors or segfaults will occur. pub unsafe fn reset(device: &Device) -> CudaResult<()> { - driver_sys::cuDevicePrimaryCtxReset_v2(device.as_raw()).to_result() + driver_sys::cuDevicePrimaryCtxReset(device.as_raw()).to_result() } /// Sets the flags for the device context, these flags will apply to any user of the primary /// context associated with this device. pub fn set_flags(&self, flags: ContextFlags) -> CudaResult<()> { - unsafe { driver_sys::cuDevicePrimaryCtxSetFlags_v2(self.device, flags.bits()).to_result() } + unsafe { driver_sys::cuDevicePrimaryCtxSetFlags(self.device, flags.bits()).to_result() } } /// Returns the raw handle to this context. @@ -291,7 +291,7 @@ impl Context { unsafe { let inner = mem::replace(&mut ctx.inner, ptr::null_mut()); - match driver_sys::cuDevicePrimaryCtxRelease_v2(ctx.device).to_result() { + match driver_sys::cuDevicePrimaryCtxRelease(ctx.device).to_result() { Ok(()) => { mem::forget(ctx); Ok(()) @@ -316,7 +316,7 @@ impl Drop for Context { unsafe { self.inner = ptr::null_mut(); - driver_sys::cuDevicePrimaryCtxRelease_v2(self.device); + driver_sys::cuDevicePrimaryCtxRelease(self.device); } } } diff --git a/crates/cust/src/device.rs b/crates/cust/src/device.rs index fb345c86..36f0cc76 100644 --- a/crates/cust/src/device.rs +++ b/crates/cust/src/device.rs @@ -295,7 +295,7 @@ impl Device { pub fn total_memory(self) -> CudaResult { unsafe { let mut memory = 0; - driver_sys::cuDeviceTotalMem_v2(&mut memory as *mut usize, self.device).to_result()?; + driver_sys::cuDeviceTotalMem(&mut memory as *mut usize, self.device).to_result()?; Ok(memory) } } diff --git a/crates/cust/src/event.rs b/crates/cust/src/event.rs index 55ed8195..18c36059 100644 --- a/crates/cust/src/event.rs +++ b/crates/cust/src/event.rs @@ -18,7 +18,7 @@ use std::ptr; use std::time::Duration; use cust_raw::driver_sys::{ - cuEventCreate, cuEventDestroy_v2, cuEventElapsedTime, cuEventQuery, cuEventRecord, + cuEventCreate, cuEventDestroy, cuEventElapsedTime, cuEventQuery, cuEventRecord, cuEventSynchronize, CUevent, }; @@ -334,7 +334,7 @@ impl Event { unsafe { let inner = mem::replace(&mut event.0, ptr::null_mut()); - match cuEventDestroy_v2(inner).to_result() { + match cuEventDestroy(inner).to_result() { Ok(()) => { mem::forget(event); Ok(()) @@ -347,7 +347,7 @@ impl Event { impl Drop for Event { fn drop(&mut self) { - unsafe { cuEventDestroy_v2(self.0) }; + unsafe { cuEventDestroy(self.0) }; } } diff --git a/crates/cust/src/graph.rs b/crates/cust/src/graph.rs index 914f42cf..b24e0963 100644 --- a/crates/cust/src/graph.rs +++ b/crates/cust/src/graph.rs @@ -395,7 +395,7 @@ impl Graph { let deps_ptr = deps.as_ptr().cast(); let mut node = MaybeUninit::::uninit(); let params = invocation.to_raw(); - driver_sys::cuGraphAddKernelNode_v2( + driver_sys::cuGraphAddKernelNode( node.as_mut_ptr().cast(), self.raw, deps_ptr, @@ -476,7 +476,7 @@ impl Graph { ); unsafe { let mut params = MaybeUninit::uninit(); - driver_sys::cuGraphKernelNodeGetParams_v2(node.to_raw(), params.as_mut_ptr()); + driver_sys::cuGraphKernelNodeGetParams(node.to_raw(), params.as_mut_ptr()); Ok(KernelInvocation::from_raw(params.assume_init())) } } diff --git a/crates/cust/src/link.rs b/crates/cust/src/link.rs index 26bf7202..d57aaf97 100644 --- a/crates/cust/src/link.rs +++ b/crates/cust/src/link.rs @@ -27,7 +27,7 @@ impl Linker { unsafe { let mut raw = MaybeUninit::uninit(); - driver_sys::cuLinkCreate_v2(0, null_mut(), null_mut(), raw.as_mut_ptr()).to_result()?; + driver_sys::cuLinkCreate(0, null_mut(), null_mut(), raw.as_mut_ptr()).to_result()?; Ok(Self { raw: raw.assume_init(), }) @@ -48,7 +48,7 @@ impl Linker { let ptx = ptx.as_ref(); unsafe { - driver_sys::cuLinkAddData_v2( + driver_sys::cuLinkAddData( self.raw, driver_sys::CUjitInputType::CU_JIT_INPUT_PTX, // cuda_sys wants *mut but from the API docs we know we retain ownership so @@ -73,7 +73,7 @@ impl Linker { let cubin = cubin.as_ref(); unsafe { - driver_sys::cuLinkAddData_v2( + driver_sys::cuLinkAddData( self.raw, driver_sys::CUjitInputType::CU_JIT_INPUT_CUBIN, // cuda_sys wants *mut but from the API docs we know we retain ownership so @@ -98,7 +98,7 @@ impl Linker { let fatbin = fatbin.as_ref(); unsafe { - driver_sys::cuLinkAddData_v2( + driver_sys::cuLinkAddData( self.raw, driver_sys::CUjitInputType::CU_JIT_INPUT_FATBINARY, // cuda_sys wants *mut but from the API docs we know we retain ownership so diff --git a/crates/cust/src/memory/array.rs b/crates/cust/src/memory/array.rs index 7d543e0c..01008633 100644 --- a/crates/cust/src/memory/array.rs +++ b/crates/cust/src/memory/array.rs @@ -14,9 +14,9 @@ use std::ptr::null; use std::ptr::null_mut; use cust_raw::driver_sys; -use cust_raw::driver_sys::cuMemcpy2D_v2; -use cust_raw::driver_sys::cuMemcpyAtoH_v2; -use cust_raw::driver_sys::cuMemcpyHtoA_v2; +use cust_raw::driver_sys::cuMemcpy2D; +use cust_raw::driver_sys::cuMemcpyAtoH; +use cust_raw::driver_sys::cuMemcpyHtoA; use cust_raw::driver_sys::CUDA_MEMCPY2D; use cust_raw::driver_sys::{CUarray, CUarray_format, CUarray_format_enum}; @@ -479,7 +479,7 @@ impl ArrayObject { } let mut handle = MaybeUninit::uninit(); - unsafe { driver_sys::cuArray3DCreate_v2(handle.as_mut_ptr(), &descriptor.desc) } + unsafe { driver_sys::cuArray3DCreate(handle.as_mut_ptr(), &descriptor.desc) } .to_result()?; Ok(Self { handle: unsafe { handle.assume_init() }, @@ -731,7 +731,7 @@ impl ArrayObject { pub fn descriptor(&self) -> CudaResult { // Use "zeroed" incase CUDA_ARRAY3D_DESCRIPTOR has uninitialized padding let mut raw_descriptor = MaybeUninit::zeroed(); - unsafe { driver_sys::cuArray3DGetDescriptor_v2(raw_descriptor.as_mut_ptr(), self.handle) } + unsafe { driver_sys::cuArray3DGetDescriptor(raw_descriptor.as_mut_ptr(), self.handle) } .to_result()?; Ok(ArrayDescriptor::from_raw(unsafe { @@ -764,8 +764,7 @@ impl ArrayObject { assert_eq!(self_size, other_size, "Array and value sizes don't match"); unsafe { if desc.height() == 0 && desc.depth() == 0 { - cuMemcpyHtoA_v2(self.handle, 0, val.as_ptr() as *const c_void, self_size) - .to_result() + cuMemcpyHtoA(self.handle, 0, val.as_ptr() as *const c_void, self_size).to_result() } else if desc.depth() == 0 { let desc = CUDA_MEMCPY2D { Height: desc.height(), @@ -787,7 +786,7 @@ impl ArrayObject { srcXInBytes: 0, srcY: 0, }; - cuMemcpy2D_v2(&desc as *const _).to_result() + cuMemcpy2D(&desc as *const _).to_result() } else { panic!(); } @@ -810,8 +809,7 @@ impl ArrayObject { assert_eq!(self_size, other_size, "Array and value sizes don't match"); unsafe { if desc.height() == 0 && desc.depth() == 0 { - cuMemcpyAtoH_v2(val.as_mut_ptr() as *mut c_void, self.handle, 0, self_size) - .to_result() + cuMemcpyAtoH(val.as_mut_ptr() as *mut c_void, self.handle, 0, self_size).to_result() } else if desc.depth() == 0 { let width = desc.width() * desc.num_channels() as usize * desc.format().mem_size(); let desc = CUDA_MEMCPY2D { @@ -832,7 +830,7 @@ impl ArrayObject { srcXInBytes: 0, srcY: 0, }; - cuMemcpy2D_v2(&desc as *const _).to_result()?; + cuMemcpy2D(&desc as *const _).to_result()?; Ok(()) } else { panic!(); diff --git a/crates/cust/src/memory/device/device_box.rs b/crates/cust/src/memory/device/device_box.rs index acb0d040..b5e765df 100644 --- a/crates/cust/src/memory/device/device_box.rs +++ b/crates/cust/src/memory/device/device_box.rs @@ -164,7 +164,7 @@ impl DeviceBox { unsafe { let new_box = DeviceBox::uninitialized()?; if mem::size_of::() != 0 { - driver_sys::cuMemsetD8_v2(new_box.as_device_ptr().as_raw(), 0, mem::size_of::()) + driver_sys::cuMemsetD8(new_box.as_device_ptr().as_raw(), 0, mem::size_of::()) .to_result()?; } Ok(new_box) @@ -430,12 +430,8 @@ impl CopyDestination for DeviceBox { let size = mem::size_of::(); if size != 0 { unsafe { - driver_sys::cuMemcpyHtoD_v2( - self.ptr.as_raw(), - val as *const T as *const c_void, - size, - ) - .to_result()? + driver_sys::cuMemcpyHtoD(self.ptr.as_raw(), val as *const T as *const c_void, size) + .to_result()? } } Ok(()) @@ -445,7 +441,7 @@ impl CopyDestination for DeviceBox { let size = mem::size_of::(); if size != 0 { unsafe { - driver_sys::cuMemcpyDtoH_v2(val as *const T as *mut c_void, self.ptr.as_raw(), size) + driver_sys::cuMemcpyDtoH(val as *const T as *mut c_void, self.ptr.as_raw(), size) .to_result()? } } @@ -457,8 +453,7 @@ impl CopyDestination> for DeviceBox { let size = mem::size_of::(); if size != 0 { unsafe { - driver_sys::cuMemcpyDtoD_v2(self.ptr.as_raw(), val.ptr.as_raw(), size) - .to_result()? + driver_sys::cuMemcpyDtoD(self.ptr.as_raw(), val.ptr.as_raw(), size).to_result()? } } Ok(()) @@ -468,8 +463,7 @@ impl CopyDestination> for DeviceBox { let size = mem::size_of::(); if size != 0 { unsafe { - driver_sys::cuMemcpyDtoD_v2(val.ptr.as_raw(), self.ptr.as_raw(), size) - .to_result()? + driver_sys::cuMemcpyDtoD(val.ptr.as_raw(), self.ptr.as_raw(), size).to_result()? } } Ok(()) @@ -479,7 +473,7 @@ impl AsyncCopyDestination for DeviceBox { unsafe fn async_copy_from(&mut self, val: &T, stream: &Stream) -> CudaResult<()> { let size = mem::size_of::(); if size != 0 { - driver_sys::cuMemcpyHtoDAsync_v2( + driver_sys::cuMemcpyHtoDAsync( self.ptr.as_raw(), val as *const _ as *const c_void, size, @@ -493,7 +487,7 @@ impl AsyncCopyDestination for DeviceBox { unsafe fn async_copy_to(&self, val: &mut T, stream: &Stream) -> CudaResult<()> { let size = mem::size_of::(); if size != 0 { - driver_sys::cuMemcpyDtoHAsync_v2( + driver_sys::cuMemcpyDtoHAsync( val as *mut _ as *mut c_void, self.ptr.as_raw(), size, @@ -508,7 +502,7 @@ impl AsyncCopyDestination> for DeviceBox { unsafe fn async_copy_from(&mut self, val: &DeviceBox, stream: &Stream) -> CudaResult<()> { let size = mem::size_of::(); if size != 0 { - driver_sys::cuMemcpyDtoDAsync_v2( + driver_sys::cuMemcpyDtoDAsync( self.ptr.as_raw(), val.ptr.as_raw(), size, @@ -522,7 +516,7 @@ impl AsyncCopyDestination> for DeviceBox { unsafe fn async_copy_to(&self, val: &mut DeviceBox, stream: &Stream) -> CudaResult<()> { let size = mem::size_of::(); if size != 0 { - driver_sys::cuMemcpyDtoDAsync_v2( + driver_sys::cuMemcpyDtoDAsync( val.ptr.as_raw(), self.ptr.as_raw(), size, diff --git a/crates/cust/src/memory/device/device_buffer.rs b/crates/cust/src/memory/device/device_buffer.rs index 6fc5dde6..873a194c 100644 --- a/crates/cust/src/memory/device/device_buffer.rs +++ b/crates/cust/src/memory/device/device_buffer.rs @@ -232,12 +232,8 @@ impl DeviceBuffer { unsafe { let new_buf = DeviceBuffer::uninitialized(size)?; if size_of::() != 0 { - driver_sys::cuMemsetD8_v2( - new_buf.as_device_ptr().as_raw(), - 0, - size_of::() * size, - ) - .to_result()?; + driver_sys::cuMemsetD8(new_buf.as_device_ptr().as_raw(), 0, size_of::() * size) + .to_result()?; } Ok(new_buf) } diff --git a/crates/cust/src/memory/device/device_slice.rs b/crates/cust/src/memory/device/device_slice.rs index 702b9d04..893b8c4d 100644 --- a/crates/cust/src/memory/device/device_slice.rs +++ b/crates/cust/src/memory/device/device_slice.rs @@ -250,7 +250,7 @@ impl DeviceSlice { // SAFETY: We know T can hold any value because it is `Pod`, and // sub-byte alignment isn't a thing so we know the alignment is right. unsafe { - driver_sys::cuMemsetD8_v2(self.as_raw_ptr(), value, self.size_in_bytes()).to_result() + driver_sys::cuMemsetD8(self.as_raw_ptr(), value, self.size_in_bytes()).to_result() } } @@ -300,7 +300,7 @@ impl DeviceSlice { 0, "Buffer pointer is not aligned to at least 2 bytes!" ); - unsafe { driver_sys::cuMemsetD16_v2(self.as_raw_ptr(), value, data_len / 2).to_result() } + unsafe { driver_sys::cuMemsetD16(self.as_raw_ptr(), value, data_len / 2).to_result() } } /// Sets the memory range of this buffer to contiguous `16-bit` values of `value` asynchronously. @@ -358,7 +358,7 @@ impl DeviceSlice { 0, "Buffer pointer is not aligned to at least 4 bytes!" ); - unsafe { driver_sys::cuMemsetD32_v2(self.as_raw_ptr(), value, data_len / 4).to_result() } + unsafe { driver_sys::cuMemsetD32(self.as_raw_ptr(), value, data_len / 4).to_result() } } /// Sets the memory range of this buffer to contiguous `32-bit` values of `value` asynchronously. @@ -651,7 +651,7 @@ impl + AsMut<[T]> + ?Sized> CopyDestination for let size = self.size_in_bytes(); if size != 0 { unsafe { - driver_sys::cuMemcpyHtoD_v2(self.as_raw_ptr(), val.as_ptr() as *const c_void, size) + driver_sys::cuMemcpyHtoD(self.as_raw_ptr(), val.as_ptr() as *const c_void, size) .to_result()? } } @@ -667,12 +667,8 @@ impl + AsMut<[T]> + ?Sized> CopyDestination for let size = self.size_in_bytes(); if size != 0 { unsafe { - driver_sys::cuMemcpyDtoH_v2( - val.as_mut_ptr() as *mut c_void, - self.as_raw_ptr(), - size, - ) - .to_result()? + driver_sys::cuMemcpyDtoH(val.as_mut_ptr() as *mut c_void, self.as_raw_ptr(), size) + .to_result()? } } Ok(()) @@ -687,8 +683,7 @@ impl CopyDestination> for DeviceSlice { let size = self.size_in_bytes(); if size != 0 { unsafe { - driver_sys::cuMemcpyDtoD_v2(self.as_raw_ptr(), val.as_raw_ptr(), size) - .to_result()? + driver_sys::cuMemcpyDtoD(self.as_raw_ptr(), val.as_raw_ptr(), size).to_result()? } } Ok(()) @@ -702,8 +697,7 @@ impl CopyDestination> for DeviceSlice { let size = self.size_in_bytes(); if size != 0 { unsafe { - driver_sys::cuMemcpyDtoD_v2(val.as_raw_ptr(), self.as_raw_ptr(), size) - .to_result()? + driver_sys::cuMemcpyDtoD(val.as_raw_ptr(), self.as_raw_ptr(), size).to_result()? } } Ok(()) @@ -729,7 +723,7 @@ impl + AsMut<[T]> + ?Sized> AsyncCopyDestination ); let size = self.size_in_bytes(); if size != 0 { - driver_sys::cuMemcpyHtoDAsync_v2( + driver_sys::cuMemcpyHtoDAsync( self.as_raw_ptr(), val.as_ptr() as *const c_void, size, @@ -748,7 +742,7 @@ impl + AsMut<[T]> + ?Sized> AsyncCopyDestination ); let size = self.size_in_bytes(); if size != 0 { - driver_sys::cuMemcpyDtoHAsync_v2( + driver_sys::cuMemcpyDtoHAsync( val.as_mut_ptr() as *mut c_void, self.as_raw_ptr(), size, @@ -767,7 +761,7 @@ impl AsyncCopyDestination> for DeviceSlice { ); let size = self.size_in_bytes(); if size != 0 { - driver_sys::cuMemcpyDtoDAsync_v2( + driver_sys::cuMemcpyDtoDAsync( self.as_raw_ptr(), val.as_raw_ptr(), size, @@ -785,7 +779,7 @@ impl AsyncCopyDestination> for DeviceSlice { ); let size = self.size_in_bytes(); if size != 0 { - driver_sys::cuMemcpyDtoDAsync_v2( + driver_sys::cuMemcpyDtoDAsync( val.as_raw_ptr(), self.as_raw_ptr(), size, diff --git a/crates/cust/src/memory/malloc.rs b/crates/cust/src/memory/malloc.rs index 78f1f356..6255778c 100644 --- a/crates/cust/src/memory/malloc.rs +++ b/crates/cust/src/memory/malloc.rs @@ -48,7 +48,7 @@ pub unsafe fn cuda_malloc(count: usize) -> CudaResult( let mut ptr = 0; let mut pitch = 0; - driver_sys::cuMemAllocPitch_v2(&mut ptr, &mut pitch, width_bytes, height, element_size) + driver_sys::cuMemAllocPitch(&mut ptr, &mut pitch, width_bytes, height, element_size) .to_result()?; Ok((DevicePointer::from_raw(ptr), pitch)) } @@ -236,7 +236,7 @@ pub unsafe fn cuda_free(ptr: DevicePointer) -> CudaResult<()> return Err(CudaError::InvalidMemoryAllocation); } - driver_sys::cuMemFree_v2(ptr.as_raw()).to_result()?; + driver_sys::cuMemFree(ptr.as_raw()).to_result()?; Ok(()) } @@ -269,7 +269,7 @@ pub unsafe fn cuda_free_unified(mut p: UnifiedPointer) -> Cuda return Err(CudaError::InvalidMemoryAllocation); } - driver_sys::cuMemFree_v2(ptr as u64).to_result()?; + driver_sys::cuMemFree(ptr as u64).to_result()?; Ok(()) } @@ -311,7 +311,7 @@ pub unsafe fn cuda_malloc_locked(count: usize) -> CudaResult<*mut T> { } let mut ptr: *mut c_void = ptr::null_mut(); - driver_sys::cuMemAllocHost_v2(&mut ptr as *mut *mut c_void, size).to_result()?; + driver_sys::cuMemAllocHost(&mut ptr as *mut *mut c_void, size).to_result()?; let ptr = ptr as *mut T; Ok(ptr) } diff --git a/crates/cust/src/memory/mod.rs b/crates/cust/src/memory/mod.rs index d9fd4838..aa349145 100644 --- a/crates/cust/src/memory/mod.rs +++ b/crates/cust/src/memory/mod.rs @@ -205,25 +205,25 @@ mod private { impl Sealed for DeviceBox {} } -/// Simple wrapper over cuMemcpyHtoD_v2 +/// Simple wrapper over cuMemcpyHtoD #[allow(clippy::missing_safety_doc)] pub unsafe fn memcpy_htod( d_ptr: driver_sys::CUdeviceptr, src_ptr: *const c_void, size: usize, ) -> CudaResult<()> { - driver_sys::cuMemcpyHtoD_v2(d_ptr, src_ptr, size).to_result()?; + driver_sys::cuMemcpyHtoD(d_ptr, src_ptr, size).to_result()?; Ok(()) } -/// Simple wrapper over cuMemcpyDtoH_v2 +/// Simple wrapper over cuMemcpyDtoH #[allow(clippy::missing_safety_doc)] pub unsafe fn memcpy_dtoh( d_ptr: *mut c_void, src_ptr: driver_sys::CUdeviceptr, size: usize, ) -> CudaResult<()> { - driver_sys::cuMemcpyDtoH_v2(d_ptr, src_ptr, size).to_result()?; + driver_sys::cuMemcpyDtoH(d_ptr, src_ptr, size).to_result()?; Ok(()) } @@ -309,7 +309,7 @@ pub unsafe fn memcpy_2d_htod( Height: height, }; - driver_sys::cuMemcpy2D_v2(&pcopy).to_result()?; + driver_sys::cuMemcpy2D(&pcopy).to_result()?; Ok(()) } @@ -395,7 +395,7 @@ pub unsafe fn memcpy_2d_dtoh( Height: height, }; - driver_sys::cuMemcpy2D_v2(&pcopy).to_result()?; + driver_sys::cuMemcpy2D(&pcopy).to_result()?; Ok(()) } @@ -409,7 +409,7 @@ pub fn mem_get_info() -> CudaResult<(usize, usize)> { let mut mem_free = 0; let mut mem_total = 0; unsafe { - driver_sys::cuMemGetInfo_v2(&mut mem_free, &mut mem_total).to_result()?; + driver_sys::cuMemGetInfo(&mut mem_free, &mut mem_total).to_result()?; } Ok((mem_free, mem_total)) } diff --git a/crates/cust/src/module.rs b/crates/cust/src/module.rs index 815cbd5c..062a912b 100644 --- a/crates/cust/src/module.rs +++ b/crates/cust/src/module.rs @@ -340,7 +340,7 @@ impl Module { /// ``` #[deprecated( since = "0.3.0", - note = "load_from_string was an inconsistent name with inconsistent params, use from_ptx/from_ptx_cstr, passing + note = "load_from_string was an inconsistent name with inconsistent params, use from_ptx/from_ptx_cstr, passing an empty slice of options (usually) " )] @@ -390,7 +390,7 @@ impl Module { let mut ptr: DevicePointer = DevicePointer::null(); let mut size: usize = 0; - driver_sys::cuModuleGetGlobal_v2( + driver_sys::cuModuleGetGlobal( &mut ptr as *mut DevicePointer as *mut driver_sys::CUdeviceptr, &mut size as *mut usize, self.inner, @@ -513,12 +513,8 @@ impl CopyDestination for Symbol<'_, T> { let size = mem::size_of::(); if size != 0 { unsafe { - driver_sys::cuMemcpyHtoD_v2( - self.ptr.as_raw(), - val as *const T as *const c_void, - size, - ) - .to_result()? + driver_sys::cuMemcpyHtoD(self.ptr.as_raw(), val as *const T as *const c_void, size) + .to_result()? } } Ok(()) @@ -528,7 +524,7 @@ impl CopyDestination for Symbol<'_, T> { let size = mem::size_of::(); if size != 0 { unsafe { - driver_sys::cuMemcpyDtoH_v2(val as *const T as *mut c_void, self.ptr.as_raw(), size) + driver_sys::cuMemcpyDtoH(val as *const T as *mut c_void, self.ptr.as_raw(), size) .to_result()? } } diff --git a/crates/cust/src/stream.rs b/crates/cust/src/stream.rs index dc67119d..41404da6 100644 --- a/crates/cust/src/stream.rs +++ b/crates/cust/src/stream.rs @@ -325,7 +325,7 @@ impl Stream { unsafe { let inner = mem::replace(&mut stream.inner, ptr::null_mut()); - match driver_sys::cuStreamDestroy_v2(inner).to_result() { + match driver_sys::cuStreamDestroy(inner).to_result() { Ok(()) => { mem::forget(stream); Ok(()) @@ -344,7 +344,7 @@ impl Drop for Stream { unsafe { let inner = mem::replace(&mut self.inner, ptr::null_mut()); - driver_sys::cuStreamDestroy_v2(inner); + driver_sys::cuStreamDestroy(inner); } } } diff --git a/crates/cust_raw/Cargo.toml b/crates/cust_raw/Cargo.toml index 94c91911..046a713f 100644 --- a/crates/cust_raw/Cargo.toml +++ b/crates/cust_raw/Cargo.toml @@ -11,6 +11,8 @@ build = "build/main.rs" [build-dependencies] bindgen = "0.71.1" +bimap = "0.6.3" +cc = "1.2.17" [package.metadata.docs.rs] features = [ diff --git a/crates/cust_raw/build/callbacks.rs b/crates/cust_raw/build/callbacks.rs new file mode 100644 index 00000000..a15e4366 --- /dev/null +++ b/crates/cust_raw/build/callbacks.rs @@ -0,0 +1,117 @@ +use std::cell; +use std::fs; +use std::path; +use std::sync; + +use bimap; +use bindgen::callbacks::{ItemInfo, ItemKind, MacroParsingBehavior, ParseCallbacks}; + +/// Struct to handle renaming of functions through macro expansion. +#[derive(Debug)] +pub(crate) struct FunctionRenames { + func_prefix: &'static str, + out_dir: path::PathBuf, + includes: path::PathBuf, + include_dirs: Vec, + macro_names: cell::RefCell>, + func_remaps: sync::OnceLock>, +} + +impl FunctionRenames { + pub fn new, I: Into>( + func_prefix: &'static str, + out_dir: P, + includes: I, + include_dirs: Vec, + ) -> Self { + Self { + func_prefix, + out_dir: out_dir.as_ref().to_path_buf(), + includes: includes.into(), + include_dirs, + macro_names: cell::RefCell::new(Vec::new()), + func_remaps: sync::OnceLock::new(), + } + } + + fn record_macro(&self, name: &str) { + self.macro_names.borrow_mut().push(name.to_string()); + } + + fn expand(&self) -> &bimap::BiHashMap { + self.func_remaps.get_or_init(|| { + let expand_me = self.out_dir.join("expand_macros.c"); + let includes = fs::read_to_string(&self.includes) + .expect("Failed to read includes for function renames"); + + let mut template = format!( + r#"{includes} +#define RENAMED2(from, to) RUST_RENAMED##from##_TO_##to +#define RENAMED(from, to) RENAMED2(from, to) +"# + ); + + for name in self.macro_names.borrow().iter() { + template.push_str(&format!("RENAMED(_{name}, {name})\n")); + } + + { + let mut temp = fs::File::create(&expand_me).unwrap(); + std::io::Write::write_all(&mut temp, template.as_bytes()).unwrap(); + } + + let mut build = cc::Build::new(); + build + .file(&expand_me) + .includes(&self.include_dirs) + .cargo_warnings(false); + + let expanded = match build.try_expand() { + Ok(expanded) => expanded, + Err(e) => panic!("Failed to expand macros: {}", e), + }; + let expanded = str::from_utf8(&expanded).unwrap(); + + let mut remaps = bimap::BiHashMap::new(); + for line in expanded.lines().rev() { + let rename_prefix = "RUST_RENAMED_"; + + if let Some((original, expanded)) = line + .strip_prefix(rename_prefix) + .and_then(|s| s.split_once("_TO_")) + .filter(|(l, r)| l != r && !r.is_empty()) + { + remaps.insert(original.to_string(), expanded.to_string()); + } + } + + fs::remove_file(&expand_me).expect("Failed to remove temporary file"); + remaps + }) + } +} + +impl ParseCallbacks for FunctionRenames { + fn will_parse_macro(&self, name: &str) -> MacroParsingBehavior { + if name.starts_with(self.func_prefix) { + self.record_macro(name); + } + MacroParsingBehavior::Default + } + + fn generated_name_override(&self, item_info: ItemInfo<'_>) -> Option { + let remaps = self.expand(); + match item_info.kind { + ItemKind::Function => remaps.get_by_right(item_info.name).cloned(), + _ => None, + } + } + + fn generated_link_name_override(&self, item_info: ItemInfo<'_>) -> Option { + let remaps = self.expand(); + match item_info.kind { + ItemKind::Function => remaps.get_by_left(item_info.name).cloned(), + _ => None, + } + } +} diff --git a/crates/cust_raw/build/main.rs b/crates/cust_raw/build/main.rs index 7137b40d..61eddfda 100644 --- a/crates/cust_raw/build/main.rs +++ b/crates/cust_raw/build/main.rs @@ -2,6 +2,7 @@ use std::env; use std::fs; use std::path; +pub mod callbacks; pub mod cuda_sdk; fn main() { @@ -79,8 +80,15 @@ fn create_cuda_driver_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) { return; } let bindgen_path = path::PathBuf::from(format!("{}/driver_sys.rs", outdir.display())); + let header = "build/driver_wrapper.h"; let bindings = bindgen::Builder::default() - .header("build/driver_wrapper.h") + .header(header) + .parse_callbacks(Box::new(callbacks::FunctionRenames::new( + "cu", + outdir, + header, + sdk.cuda_include_paths().to_owned(), + ))) .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) .clang_args( sdk.cuda_include_paths() @@ -115,8 +123,15 @@ fn create_cuda_runtime_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) { return; } let bindgen_path = path::PathBuf::from(format!("{}/runtime_sys.rs", outdir.display())); + let header = "build/runtime_wrapper.h"; let bindings = bindgen::Builder::default() - .header("build/runtime_wrapper.h") + .header(header) + .parse_callbacks(Box::new(callbacks::FunctionRenames::new( + "cuda", + outdir, + header, + sdk.cuda_include_paths().to_owned(), + ))) .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) .clang_args( sdk.cuda_include_paths() @@ -148,16 +163,23 @@ fn create_cublas_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) { #[rustfmt::skip] let params = &[ (cfg!(feature = "cublas"), "cublas", "^cublas.*", "^CUBLAS.*"), - (cfg!(feature = "cublaslt"), "cublaslt", "^cublasLt.*", "^CUBLASLT.*"), - (cfg!(feature = "cublasxt"), "cublasxt", "^cublasXt.*", "^CUBLASXT.*"), + (cfg!(feature = "cublaslt"), "cublasLt", "^cublasLt.*", "^CUBLASLT.*"), + (cfg!(feature = "cublasxt"), "cublasXt", "^cublasXt.*", "^CUBLASXT.*"), ]; for (should_generate, pkg, tf, var) in params { if !should_generate { continue; } let bindgen_path = path::PathBuf::from(format!("{}/{pkg}_sys.rs", outdir.display())); + let header = format!("build/{pkg}_wrapper.h"); let bindings = bindgen::Builder::default() - .header(format!("build/{pkg}_wrapper.h")) + .header(&header) + .parse_callbacks(Box::new(callbacks::FunctionRenames::new( + pkg, + outdir, + header, + sdk.cuda_include_paths().to_owned(), + ))) .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) .clang_args( sdk.cuda_include_paths() diff --git a/crates/optix-sys/build/main.rs b/crates/optix-sys/build/main.rs index 078618f6..9301f6e6 100644 --- a/crates/optix-sys/build/main.rs +++ b/crates/optix-sys/build/main.rs @@ -15,6 +15,9 @@ fn main() { .expect("Cannot find transitive metadata 'cuda_include' from cust_raw package."); println!("cargo::rerun-if-changed=build"); + for e in sdk.related_optix_envs() { + println!("cargo::rerun-if-env-changed={}", e); + } // Emit metadata for the build script. println!("cargo::metadata=root={}", sdk.optix_root().display()); println!("cargo::metadata=version={}", sdk.optix_version()); diff --git a/crates/optix-sys/build/optix_sdk.rs b/crates/optix-sys/build/optix_sdk.rs index bc0cf736..46b7329f 100644 --- a/crates/optix-sys/build/optix_sdk.rs +++ b/crates/optix-sys/build/optix_sdk.rs @@ -3,6 +3,8 @@ use std::error; use std::fs; use std::path; +const OPTIX_ROOT_ENVS: &[&str] = &["OPTIX_ROOT", "OPTIX_ROOT_DIR"]; + /// Represents the OptiX SDK installation. #[derive(Debug, Clone)] pub struct OptiXSdk { @@ -60,14 +62,19 @@ impl OptiXSdk { self.optix_version % 100 } + pub fn related_optix_envs(&self) -> Vec { + OPTIX_ROOT_ENVS.iter().map(|s| s.to_string()).collect() + } + fn find_optix_root() -> Option { // the optix SDK installer sets OPTIX_ROOT_DIR whenever it installs. // We also check OPTIX_ROOT first in case someone wants to override it without overriding // the SDK-set variable. - env::var("OPTIX_ROOT") - .ok() - .or_else(|| env::var("OPTIX_ROOT_DIR").ok()) + OPTIX_ROOT_ENVS + .iter() + .filter_map(|env| env::var(env).ok()) .map(path::PathBuf::from) + .next() } /// Parses the content of the `optix.h` header file to extract the OptiX version. diff --git a/crates/optix/examples/ex02_pipeline/build.rs b/crates/optix/examples/ex02_pipeline/build.rs index 4e82edf0..4eb0f5b1 100644 --- a/crates/optix/examples/ex02_pipeline/build.rs +++ b/crates/optix/examples/ex02_pipeline/build.rs @@ -7,7 +7,7 @@ fn main() { println!("cargo::rerun-if-changed=build.rs"); let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap(); - let optix_include_paths = env::var_os("DEP_OPTIX_OPTIX_INCLUDE") + let optix_include_paths = env::var_os("DEP_OPTIX_INCLUDE_DIR") .map(|s| env::split_paths(s.as_os_str()).collect::>()) .expect("Cannot find transitive metadata 'optix_include' from optix-sys package."); diff --git a/crates/optix/examples/ex03_window/build.rs b/crates/optix/examples/ex03_window/build.rs index 06122b04..63d1bced 100644 --- a/crates/optix/examples/ex03_window/build.rs +++ b/crates/optix/examples/ex03_window/build.rs @@ -5,7 +5,7 @@ fn main() { println!("cargo::rerun-if-changed=build.rs"); let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap(); - let optix_include_paths = env::var_os("DEP_OPTIX_OPTIX_INCLUDE") + let optix_include_paths = env::var_os("DEP_OPTIX_INCLUDE_DIR") .map(|s| env::split_paths(s.as_os_str()).collect::>()) .expect("Cannot find transitive metadata 'optix_include' from optix-sys package."); diff --git a/examples/cuda/cpu/add/Cargo.toml b/examples/cuda/cpu/add/Cargo.toml index ae99bb78..523e0a75 100644 --- a/examples/cuda/cpu/add/Cargo.toml +++ b/examples/cuda/cpu/add/Cargo.toml @@ -14,8 +14,6 @@ log = "=0.4.17" regex-syntax = "=0.6.28" regex = "=1.11.1" thread_local = "=1.1.4" -jobserver = "=0.1.25" -cc = "=1.0.78" rayon = "=1.10" rayon-core = "=1.12.1" byteorder = "=1.4.0"