1- use crate :: { error:: * , sys} ;
2- use cust:: stream:: Stream ;
31use std:: ffi:: CString ;
42use std:: mem:: { self , MaybeUninit } ;
53use std:: os:: raw:: c_char;
64use std:: ptr;
75
8- type Result < T , E = Error > = std:: result:: Result < T , E > ;
6+ use cust:: stream:: Stream ;
7+ use cust_raw:: cublas_sys;
8+ use cust_raw:: driver_sys;
9+
10+ use super :: error:: DropResult ;
11+ use super :: error:: ToResult as _;
12+
13+ type Result < T , E = super :: error:: Error > = std:: result:: Result < T , E > ;
914
1015bitflags:: bitflags! {
1116 /// Configures precision levels for the math in cuBLAS.
12- #[ derive( Default ) ]
17+ #[ derive( Debug , Default , Clone , Copy , PartialEq , Eq , Hash ) ]
1318 pub struct MathMode : u32 {
1419 /// Highest performance mode which uses compute and intermediate storage precisions
1520 /// with at least the same number of mantissa and exponent bits as requested. Will
@@ -68,7 +73,7 @@ bitflags::bitflags! {
6873/// - [Matrix Multiplication <span style="float:right;">`gemm`</span>](CublasContext::gemm)
6974#[ derive( Debug ) ]
7075pub struct CublasContext {
71- pub ( crate ) raw : sys :: v2 :: cublasHandle_t ,
76+ pub ( crate ) raw : cublas_sys :: cublasHandle_t ,
7277}
7378
7479impl CublasContext {
@@ -87,10 +92,10 @@ impl CublasContext {
8792 pub fn new ( ) -> Result < Self > {
8893 let mut raw = MaybeUninit :: uninit ( ) ;
8994 unsafe {
90- sys :: v2 :: cublasCreate_v2 ( raw. as_mut_ptr ( ) ) . to_result ( ) ?;
91- sys :: v2 :: cublasSetPointerMode_v2 (
95+ cublas_sys :: cublasCreate_v2 ( raw. as_mut_ptr ( ) ) . to_result ( ) ?;
96+ cublas_sys :: cublasSetPointerMode_v2 (
9297 raw. assume_init ( ) ,
93- sys :: v2 :: cublasPointerMode_t:: CUBLAS_POINTER_MODE_DEVICE ,
98+ cublas_sys :: cublasPointerMode_t:: CUBLAS_POINTER_MODE_DEVICE ,
9499 )
95100 . to_result ( ) ?;
96101 Ok ( Self {
@@ -107,7 +112,7 @@ impl CublasContext {
107112
108113 unsafe {
109114 let inner = mem:: replace ( & mut ctx. raw , ptr:: null_mut ( ) ) ;
110- match sys :: v2 :: cublasDestroy_v2 ( inner) . to_result ( ) {
115+ match cublas_sys :: cublasDestroy_v2 ( inner) . to_result ( ) {
111116 Ok ( ( ) ) => {
112117 mem:: forget ( ctx) ;
113118 Ok ( ( ) )
@@ -122,7 +127,7 @@ impl CublasContext {
122127 let mut raw = MaybeUninit :: < u32 > :: uninit ( ) ;
123128 unsafe {
124129 // getVersion can't fail
125- sys :: v2 :: cublasGetVersion_v2 ( self . raw , raw. as_mut_ptr ( ) . cast ( ) )
130+ cublas_sys :: cublasGetVersion_v2 ( self . raw , raw. as_mut_ptr ( ) . cast ( ) )
126131 . to_result ( )
127132 . unwrap ( ) ;
128133
@@ -140,17 +145,17 @@ impl CublasContext {
140145 ) -> Result < T > {
141146 unsafe {
142147 // cudaStream_t is the same as CUstream
143- sys :: v2 :: cublasSetStream_v2 (
148+ cublas_sys :: cublasSetStream_v2 (
144149 self . raw ,
145- mem:: transmute :: < * mut cust :: sys :: CUstream_st , * mut cublas_sys:: v2 :: CUstream_st > (
150+ mem:: transmute :: < * mut driver_sys :: CUstream_st , * mut cublas_sys:: CUstream_st > (
146151 stream. as_inner ( ) ,
147152 ) ,
148153 )
149154 . to_result ( ) ?;
150155 let res = func ( self ) ?;
151156 // reset the stream back to NULL just in case someone calls with_stream, then drops the stream, and tries to
152157 // execute a raw sys function with the context's handle.
153- sys :: v2 :: cublasSetStream_v2 ( self . raw , ptr:: null_mut ( ) ) . to_result ( ) ?;
158+ cublas_sys :: cublasSetStream_v2 ( self . raw , ptr:: null_mut ( ) ) . to_result ( ) ?;
154159 Ok ( res)
155160 }
156161 }
@@ -180,12 +185,12 @@ impl CublasContext {
180185 /// ```
181186 pub fn set_atomics_mode ( & self , allowed : bool ) -> Result < ( ) > {
182187 unsafe {
183- Ok ( sys :: v2 :: cublasSetAtomicsMode (
188+ Ok ( cublas_sys :: cublasSetAtomicsMode (
184189 self . raw ,
185190 if allowed {
186- sys :: v2 :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_ALLOWED
191+ cublas_sys :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_ALLOWED
187192 } else {
188- sys :: v2 :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_NOT_ALLOWED
193+ cublas_sys :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_NOT_ALLOWED
189194 } ,
190195 )
191196 . to_result ( ) ?)
@@ -210,10 +215,10 @@ impl CublasContext {
210215 pub fn get_atomics_mode ( & self ) -> Result < bool > {
211216 let mut mode = MaybeUninit :: uninit ( ) ;
212217 unsafe {
213- sys :: v2 :: cublasGetAtomicsMode ( self . raw , mode. as_mut_ptr ( ) ) . to_result ( ) ?;
218+ cublas_sys :: cublasGetAtomicsMode ( self . raw , mode. as_mut_ptr ( ) ) . to_result ( ) ?;
214219 Ok ( match mode. assume_init ( ) {
215- sys :: v2 :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_ALLOWED => true ,
216- sys :: v2 :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_NOT_ALLOWED => false ,
220+ cublas_sys :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_ALLOWED => true ,
221+ cublas_sys :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_NOT_ALLOWED => false ,
217222 } )
218223 }
219224 }
@@ -233,9 +238,9 @@ impl CublasContext {
233238 /// ```
234239 pub fn set_math_mode ( & self , math_mode : MathMode ) -> Result < ( ) > {
235240 unsafe {
236- Ok ( sys :: v2 :: cublasSetMathMode (
241+ Ok ( cublas_sys :: cublasSetMathMode (
237242 self . raw ,
238- mem:: transmute :: < u32 , cublas_sys:: v2 :: cublasMath_t > ( math_mode. bits ( ) ) ,
243+ mem:: transmute :: < u32 , cublas_sys:: cublasMath_t > ( math_mode. bits ( ) ) ,
239244 )
240245 . to_result ( ) ?)
241246 }
@@ -258,7 +263,7 @@ impl CublasContext {
258263 pub fn get_math_mode ( & self ) -> Result < MathMode > {
259264 let mut mode = MaybeUninit :: uninit ( ) ;
260265 unsafe {
261- sys :: v2 :: cublasGetMathMode ( self . raw , mode. as_mut_ptr ( ) ) . to_result ( ) ?;
266+ cublas_sys :: cublasGetMathMode ( self . raw , mode. as_mut_ptr ( ) ) . to_result ( ) ?;
262267 Ok ( MathMode :: from_bits ( mode. assume_init ( ) as u32 )
263268 . expect ( "Invalid MathMode from cuBLAS" ) )
264269 }
@@ -298,7 +303,7 @@ impl CublasContext {
298303 let path = log_file_name. map ( |p| CString :: new ( p) . expect ( "nul in log_file_name" ) ) ;
299304 let path_ptr = path. map_or ( ptr:: null ( ) , |s| s. as_ptr ( ) ) ;
300305
301- sys :: v2 :: cublasLoggerConfigure (
306+ cublas_sys :: cublasLoggerConfigure (
302307 enable as i32 ,
303308 log_to_stdout as i32 ,
304309 log_to_stderr as i32 ,
@@ -315,7 +320,7 @@ impl CublasContext {
315320 ///
316321 /// The callback must not panic and unwind.
317322 pub unsafe fn set_logger_callback ( callback : Option < unsafe extern "C" fn ( * const c_char ) > ) {
318- sys :: v2 :: cublasSetLoggerCallback ( callback)
323+ cublas_sys :: cublasSetLoggerCallback ( callback)
319324 . to_result ( )
320325 . unwrap ( ) ;
321326 }
@@ -324,7 +329,7 @@ impl CublasContext {
324329 pub fn get_logger_callback ( ) -> Option < unsafe extern "C" fn ( * const c_char ) > {
325330 let mut cb = MaybeUninit :: uninit ( ) ;
326331 unsafe {
327- sys :: v2 :: cublasGetLoggerCallback ( cb. as_mut_ptr ( ) )
332+ cublas_sys :: cublasGetLoggerCallback ( cb. as_mut_ptr ( ) )
328333 . to_result ( )
329334 . unwrap ( ) ;
330335 cb. assume_init ( )
@@ -335,7 +340,7 @@ impl CublasContext {
335340impl Drop for CublasContext {
336341 fn drop ( & mut self ) {
337342 unsafe {
338- sys :: v2 :: cublasDestroy_v2 ( self . raw ) ;
343+ cublas_sys :: cublasDestroy_v2 ( self . raw ) ;
339344 }
340345 }
341346}
0 commit comments