1- use crate :: { error:: * , sys} ;
2- use cust:: stream:: Stream ;
31use std:: ffi:: CString ;
42use std:: mem:: { self , MaybeUninit } ;
53use std:: os:: raw:: c_char;
64use std:: ptr;
75
8- type Result < T , E = Error > = std:: result:: Result < T , E > ;
6+ use cust:: stream:: Stream ;
7+ use cust_raw:: cublas_sys;
8+ use cust_raw:: driver_sys;
9+
10+ use super :: error:: DropResult ;
11+ use super :: error:: ToResult as _;
12+
13+ type Result < T , E = super :: error:: Error > = std:: result:: Result < T , E > ;
914
1015bitflags:: bitflags! {
1116 /// Configures precision levels for the math in cuBLAS.
12- #[ derive( Default ) ]
17+ #[ derive( Debug , Default , Clone , Copy , PartialEq , Eq , Hash ) ]
1318 pub struct MathMode : u32 {
1419 /// Highest performance mode which uses compute and intermediate storage precisions
1520 /// with at least the same number of mantissa and exponent bits as requested. Will
@@ -68,7 +73,7 @@ bitflags::bitflags! {
6873/// - [Matrix Multiplication <span style="float:right;">`gemm`</span>](CublasContext::gemm)
6974#[ derive( Debug ) ]
7075pub struct CublasContext {
71- pub ( crate ) raw : sys :: v2 :: cublasHandle_t ,
76+ pub ( crate ) raw : cublas_sys :: cublasHandle_t ,
7277}
7378
7479impl CublasContext {
@@ -87,10 +92,10 @@ impl CublasContext {
8792 pub fn new ( ) -> Result < Self > {
8893 let mut raw = MaybeUninit :: uninit ( ) ;
8994 unsafe {
90- sys :: v2 :: cublasCreate_v2 ( raw. as_mut_ptr ( ) ) . to_result ( ) ?;
91- sys :: v2 :: cublasSetPointerMode_v2 (
95+ cublas_sys :: cublasCreate_v2 ( raw. as_mut_ptr ( ) ) . to_result ( ) ?;
96+ cublas_sys :: cublasSetPointerMode_v2 (
9297 raw. assume_init ( ) ,
93- sys :: v2 :: cublasPointerMode_t:: CUBLAS_POINTER_MODE_DEVICE ,
98+ cublas_sys :: cublasPointerMode_t:: CUBLAS_POINTER_MODE_DEVICE ,
9499 )
95100 . to_result ( ) ?;
96101 Ok ( Self {
@@ -107,7 +112,7 @@ impl CublasContext {
107112
108113 unsafe {
109114 let inner = mem:: replace ( & mut ctx. raw , ptr:: null_mut ( ) ) ;
110- match sys :: v2 :: cublasDestroy_v2 ( inner) . to_result ( ) {
115+ match cublas_sys :: cublasDestroy_v2 ( inner) . to_result ( ) {
111116 Ok ( ( ) ) => {
112117 mem:: forget ( ctx) ;
113118 Ok ( ( ) )
@@ -122,7 +127,7 @@ impl CublasContext {
122127 let mut raw = MaybeUninit :: < u32 > :: uninit ( ) ;
123128 unsafe {
124129 // getVersion can't fail
125- sys :: v2 :: cublasGetVersion_v2 ( self . raw , raw. as_mut_ptr ( ) . cast ( ) )
130+ cublas_sys :: cublasGetVersion_v2 ( self . raw , raw. as_mut_ptr ( ) . cast ( ) )
126131 . to_result ( )
127132 . unwrap ( ) ;
128133
@@ -140,17 +145,17 @@ impl CublasContext {
140145 ) -> Result < T > {
141146 unsafe {
142147 // cudaStream_t is the same as CUstream
143- sys :: v2 :: cublasSetStream_v2 (
148+ cublas_sys :: cublasSetStream_v2 (
144149 self . raw ,
145- mem:: transmute :: < * mut cust :: sys :: CUstream_st , * mut cublas_sys:: v2 :: CUstream_st > (
150+ mem:: transmute :: < * mut driver_sys :: CUstream_st , * mut cublas_sys:: CUstream_st > (
146151 stream. as_inner ( ) ,
147152 ) ,
148153 )
149154 . to_result ( ) ?;
150155 let res = func ( self ) ?;
151156 // reset the stream back to NULL just in case someone calls with_stream, then drops the stream, and tries to
152157 // execute a raw sys function with the context's handle.
153- sys :: v2 :: cublasSetStream_v2 ( self . raw , ptr:: null_mut ( ) ) . to_result ( ) ?;
158+ cublas_sys :: cublasSetStream_v2 ( self . raw , ptr:: null_mut ( ) ) . to_result ( ) ?;
154159 Ok ( res)
155160 }
156161 }
@@ -180,12 +185,12 @@ impl CublasContext {
180185 /// ```
181186 pub fn set_atomics_mode ( & self , allowed : bool ) -> Result < ( ) > {
182187 unsafe {
183- Ok ( sys :: v2 :: cublasSetAtomicsMode (
188+ Ok ( cublas_sys :: cublasSetAtomicsMode (
184189 self . raw ,
185190 if allowed {
186- sys :: v2 :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_ALLOWED
191+ cublas_sys :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_ALLOWED
187192 } else {
188- sys :: v2 :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_NOT_ALLOWED
193+ cublas_sys :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_NOT_ALLOWED
189194 } ,
190195 )
191196 . to_result ( ) ?)
@@ -210,10 +215,11 @@ impl CublasContext {
210215 pub fn get_atomics_mode ( & self ) -> Result < bool > {
211216 let mut mode = MaybeUninit :: uninit ( ) ;
212217 unsafe {
213- sys :: v2 :: cublasGetAtomicsMode ( self . raw , mode. as_mut_ptr ( ) ) . to_result ( ) ?;
218+ cublas_sys :: cublasGetAtomicsMode ( self . raw , mode. as_mut_ptr ( ) ) . to_result ( ) ?;
214219 Ok ( match mode. assume_init ( ) {
215- sys:: v2:: cublasAtomicsMode_t:: CUBLAS_ATOMICS_ALLOWED => true ,
216- sys:: v2:: cublasAtomicsMode_t:: CUBLAS_ATOMICS_NOT_ALLOWED => false ,
220+ cublas_sys:: cublasAtomicsMode_t:: CUBLAS_ATOMICS_ALLOWED => true ,
221+ cublas_sys:: cublasAtomicsMode_t:: CUBLAS_ATOMICS_NOT_ALLOWED => false ,
222+ _ => false ,
217223 } )
218224 }
219225 }
@@ -233,9 +239,9 @@ impl CublasContext {
233239 /// ```
234240 pub fn set_math_mode ( & self , math_mode : MathMode ) -> Result < ( ) > {
235241 unsafe {
236- Ok ( sys :: v2 :: cublasSetMathMode (
242+ Ok ( cublas_sys :: cublasSetMathMode (
237243 self . raw ,
238- mem:: transmute :: < u32 , cublas_sys:: v2 :: cublasMath_t > ( math_mode. bits ( ) ) ,
244+ mem:: transmute :: < u32 , cublas_sys:: cublasMath_t > ( math_mode. bits ( ) ) ,
239245 )
240246 . to_result ( ) ?)
241247 }
@@ -258,7 +264,7 @@ impl CublasContext {
258264 pub fn get_math_mode ( & self ) -> Result < MathMode > {
259265 let mut mode = MaybeUninit :: uninit ( ) ;
260266 unsafe {
261- sys :: v2 :: cublasGetMathMode ( self . raw , mode. as_mut_ptr ( ) ) . to_result ( ) ?;
267+ cublas_sys :: cublasGetMathMode ( self . raw , mode. as_mut_ptr ( ) ) . to_result ( ) ?;
262268 Ok ( MathMode :: from_bits ( mode. assume_init ( ) as u32 )
263269 . expect ( "Invalid MathMode from cuBLAS" ) )
264270 }
@@ -298,7 +304,7 @@ impl CublasContext {
298304 let path = log_file_name. map ( |p| CString :: new ( p) . expect ( "nul in log_file_name" ) ) ;
299305 let path_ptr = path. map_or ( ptr:: null ( ) , |s| s. as_ptr ( ) ) ;
300306
301- sys :: v2 :: cublasLoggerConfigure (
307+ cublas_sys :: cublasLoggerConfigure (
302308 enable as i32 ,
303309 log_to_stdout as i32 ,
304310 log_to_stderr as i32 ,
@@ -315,7 +321,7 @@ impl CublasContext {
315321 ///
316322 /// The callback must not panic and unwind.
317323 pub unsafe fn set_logger_callback ( callback : Option < unsafe extern "C" fn ( * const c_char ) > ) {
318- sys :: v2 :: cublasSetLoggerCallback ( callback)
324+ cublas_sys :: cublasSetLoggerCallback ( callback)
319325 . to_result ( )
320326 . unwrap ( ) ;
321327 }
@@ -324,7 +330,7 @@ impl CublasContext {
324330 pub fn get_logger_callback ( ) -> Option < unsafe extern "C" fn ( * const c_char ) > {
325331 let mut cb = MaybeUninit :: uninit ( ) ;
326332 unsafe {
327- sys :: v2 :: cublasGetLoggerCallback ( cb. as_mut_ptr ( ) )
333+ cublas_sys :: cublasGetLoggerCallback ( cb. as_mut_ptr ( ) )
328334 . to_result ( )
329335 . unwrap ( ) ;
330336 cb. assume_init ( )
@@ -335,7 +341,7 @@ impl CublasContext {
335341impl Drop for CublasContext {
336342 fn drop ( & mut self ) {
337343 unsafe {
338- sys :: v2 :: cublasDestroy_v2 ( self . raw ) ;
344+ cublas_sys :: cublasDestroy_v2 ( self . raw ) ;
339345 }
340346 }
341347}
0 commit comments