@@ -423,6 +423,10 @@ impl CandleBackend {
423423 if dtype != DType :: F16
424424 || !cfg ! ( feature = "flash-attn" )
425425 || get_runtime_compute_cap ( ) . unwrap ( ) < 80
426+ || & std:: env:: var ( "USE_FLASH_ATTENTION" )
427+ . unwrap_or ( "True" . to_string ( ) )
428+ . to_lowercase ( )
429+ != "true"
426430 {
427431 return Err ( BackendError :: Start ( "Mistral is only supported on Cuda devices in fp16 with flash attention v2 enabled" . to_string ( ) ) ) ;
428432 }
@@ -435,6 +439,10 @@ impl CandleBackend {
435439 ( Config :: Gte ( config) , Device :: Cuda ( _) ) => {
436440 if dtype != DType :: F16
437441 || !cfg ! ( any( feature = "flash-attn" , feature = "flash-attn-v1" ) )
442+ || & std:: env:: var ( "USE_FLASH_ATTENTION" )
443+ . unwrap_or ( "True" . to_string ( ) )
444+ . to_lowercase ( )
445+ != "true"
438446 {
439447 tracing:: info!( "Starting GTE model on {:?}" , device) ;
440448 Ok ( Box :: new ( GTEModel :: load ( vb, & config, model_type) . s ( ) ?) )
@@ -447,6 +455,10 @@ impl CandleBackend {
447455 ( Config :: Qwen2 ( config) , Device :: Cuda ( _) ) => {
448456 if dtype != DType :: F16
449457 || !cfg ! ( any( feature = "flash-attn" , feature = "flash-attn-v1" ) )
458+ || & std:: env:: var ( "USE_FLASH_ATTENTION" )
459+ . unwrap_or ( "True" . to_string ( ) )
460+ . to_lowercase ( )
461+ != "true"
450462 {
451463 return Err ( BackendError :: Start ( "Qwen2 is only supported on Cuda devices in fp16 with flash attention v2 enabled" . to_string ( ) ) ) ;
452464 }
@@ -459,6 +471,10 @@ impl CandleBackend {
459471 ( Config :: Qwen3 ( config) , Device :: Cuda ( _) ) => {
460472 if dtype != DType :: F16
461473 || !cfg ! ( any( feature = "flash-attn" , feature = "flash-attn-v1" ) )
474+ || & std:: env:: var ( "USE_FLASH_ATTENTION" )
475+ . unwrap_or ( "True" . to_string ( ) )
476+ . to_lowercase ( )
477+ != "true"
462478 {
463479 tracing:: info!( "Starting Qwen3 model on {:?}" , device) ;
464480 Ok ( Box :: new ( Qwen3Model :: load ( vb, & config, model_type) . s ( ) ?) )
0 commit comments