@@ -12,12 +12,12 @@ use crate::compute_cap::{
1212} ;
1313use crate :: models:: {
1414 BertConfig , BertModel , DistilBertConfig , DistilBertModel , JinaBertModel , JinaCodeBertModel ,
15- Model , NomicBertModel , NomicConfig ,
15+ MistralConfig , Model , NomicBertModel , NomicConfig ,
1616} ;
1717#[ cfg( feature = "cuda" ) ]
1818use crate :: models:: {
1919 FlashBertModel , FlashDistilBertModel , FlashJinaBertModel , FlashJinaCodeBertModel ,
20- FlashNomicBertModel ,
20+ FlashMistralModel , FlashNomicBertModel ,
2121} ;
2222use anyhow:: Context ;
2323use candle:: { DType , Device } ;
@@ -56,6 +56,7 @@ enum Config {
5656 DistilBert ( DistilBertConfig ) ,
5757 #[ serde( rename( deserialize = "nomic_bert" ) ) ]
5858 NomicBert ( NomicConfig ) ,
59+ Mistral ( MistralConfig ) ,
5960}
6061
6162pub struct CandleBackend {
@@ -69,6 +70,54 @@ impl CandleBackend {
6970 dtype : String ,
7071 model_type : ModelType ,
7172 ) -> Result < Self , BackendError > {
73+ // Default files
74+ let default_safetensors = model_path. join ( "model.safetensors" ) ;
75+ let default_pytorch = model_path. join ( "pytorch_model.bin" ) ;
76+
77+ // Single Files
78+ let model_files = if default_safetensors. exists ( ) {
79+ vec ! [ default_safetensors]
80+ } else if default_pytorch. exists ( ) {
81+ vec ! [ default_pytorch]
82+ }
83+ // Sharded weights
84+ else {
85+ // Get index file
86+ let index_file = model_path. join ( "model.safetensors.index.json" ) ;
87+
88+ // Parse file
89+ let index_file_string: String = std:: fs:: read_to_string ( & index_file)
90+ . map_err ( |err| BackendError :: Start ( err. to_string ( ) ) ) ?;
91+ let json: serde_json:: Value = serde_json:: from_str ( & index_file_string)
92+ . map_err ( |err| BackendError :: Start ( err. to_string ( ) ) ) ?;
93+
94+ let weight_map = match json. get ( "weight_map" ) {
95+ None => {
96+ return Err ( BackendError :: Start ( format ! (
97+ "no weight map in {index_file:?}"
98+ ) ) ) ;
99+ }
100+ Some ( serde_json:: Value :: Object ( map) ) => map,
101+ Some ( _) => {
102+ return Err ( BackendError :: Start ( format ! (
103+ "weight map in {index_file:?} is not a map"
104+ ) ) ) ;
105+ }
106+ } ;
107+ let mut safetensors_files = std:: collections:: HashSet :: new ( ) ;
108+ for value in weight_map. values ( ) {
109+ if let Some ( file) = value. as_str ( ) {
110+ safetensors_files. insert ( file. to_string ( ) ) ;
111+ }
112+ }
113+
114+ // Collect paths
115+ safetensors_files
116+ . iter ( )
117+ . map ( |n| model_path. join ( n) )
118+ . collect ( )
119+ } ;
120+
72121 // Load config
73122 let config: String = std:: fs:: read_to_string ( model_path. join ( "config.json" ) )
74123 . context ( "Unable to read config file" )
@@ -115,17 +164,10 @@ impl CandleBackend {
115164 ) ) )
116165 } ?;
117166
118- let safetensors_path = model_path. join ( "model.safetensors" ) ;
119- let vb = if safetensors_path. exists ( ) {
120- unsafe {
121- VarBuilder :: from_mmaped_safetensors (
122- & [ model_path. join ( "model.safetensors" ) ] ,
123- dtype,
124- & device,
125- )
126- }
167+ let vb = if model_files. len ( ) == 1 && model_files[ 0 ] . extension ( ) . unwrap ( ) == "bin" {
168+ VarBuilder :: from_pth ( & model_files[ 0 ] , dtype, & device)
127169 } else {
128- VarBuilder :: from_pth ( model_path . join ( "pytorch_model.bin" ) , dtype, & device)
170+ unsafe { VarBuilder :: from_mmaped_safetensors ( & model_files , dtype, & device) }
129171 }
130172 . s ( ) ?;
131173
@@ -136,7 +178,7 @@ impl CandleBackend {
136178 ) ) ,
137179 ( Config :: Bert ( config) , Device :: Cpu | Device :: Metal ( _) ) => match config {
138180 BertConfigWrapper :: JinaBert ( config) => {
139- tracing:: info!( "Starting JinaBertModel model on {:?}" , device) ;
181+ tracing:: info!( "Starting JinaBert model on {:?}" , device) ;
140182 Ok ( Box :: new ( JinaBertModel :: load ( vb, & config, model_type) . s ( ) ?) )
141183 }
142184 BertConfigWrapper :: JinaCodeBert ( config) => {
@@ -160,15 +202,19 @@ impl CandleBackend {
160202 ) )
161203 }
162204 ( Config :: DistilBert ( config) , Device :: Cpu | Device :: Metal ( _) ) => {
163- tracing:: info!( "Starting DistilBertModel model on {:?}" , device) ;
205+ tracing:: info!( "Starting DistilBert model on {:?}" , device) ;
164206 Ok ( Box :: new (
165207 DistilBertModel :: load ( vb, & config, model_type) . s ( ) ?,
166208 ) )
167209 }
168210 ( Config :: NomicBert ( config) , Device :: Cpu | Device :: Metal ( _) ) => {
169- tracing:: info!( "Starting NomicBertModel model on {:?}" , device) ;
211+ tracing:: info!( "Starting NomicBert model on {:?}" , device) ;
170212 Ok ( Box :: new ( NomicBertModel :: load ( vb, & config, model_type) . s ( ) ?) )
171213 }
214+ ( Config :: Mistral ( _) , Device :: Cpu | Device :: Metal ( _) ) => Err ( BackendError :: Start (
215+ "Mistral is only supported on Cuda devices in fp16 with flash attention enabled"
216+ . to_string ( ) ,
217+ ) ) ,
172218 #[ cfg( feature = "cuda" ) ]
173219 ( Config :: Bert ( config) , Device :: Cuda ( _) ) => {
174220 if cfg ! ( any( feature = "flash-attn" , feature = "flash-attn-v1" ) )
@@ -198,7 +244,7 @@ impl CandleBackend {
198244 } else {
199245 match config {
200246 BertConfigWrapper :: JinaBert ( config) => {
201- tracing:: info!( "Starting JinaBertModel model on {:?}" , device) ;
247+ tracing:: info!( "Starting JinaBert model on {:?}" , device) ;
202248 Ok ( Box :: new ( JinaBertModel :: load ( vb, & config, model_type) . s ( ) ?) )
203249 }
204250 BertConfigWrapper :: JinaCodeBert ( config) => {
@@ -245,7 +291,7 @@ impl CandleBackend {
245291 . to_lowercase ( )
246292 == "true"
247293 {
248- tracing:: info!( "Starting FlashDistilBertModel model on {:?}" , device) ;
294+ tracing:: info!( "Starting FlashDistilBert model on {:?}" , device) ;
249295 Ok ( Box :: new (
250296 FlashDistilBertModel :: load ( vb, & config, model_type) . s ( ) ?,
251297 ) )
@@ -265,15 +311,28 @@ impl CandleBackend {
265311 . to_lowercase ( )
266312 == "true"
267313 {
268- tracing:: info!( "Starting FlashNomicBertModel model on {:?}" , device) ;
314+ tracing:: info!( "Starting FlashNomicBert model on {:?}" , device) ;
269315 Ok ( Box :: new (
270316 FlashNomicBertModel :: load ( vb, & config, model_type) . s ( ) ?,
271317 ) )
272318 } else {
273- tracing:: info!( "Starting NomicBertModel model on {:?}" , device) ;
319+ tracing:: info!( "Starting NomicBert model on {:?}" , device) ;
274320 Ok ( Box :: new ( NomicBertModel :: load ( vb, & config, model_type) . s ( ) ?) )
275321 }
276322 }
323+ #[ cfg( feature = "cuda" ) ]
324+ ( Config :: Mistral ( config) , Device :: Cuda ( _) ) => {
325+ if dtype != DType :: F16
326+ || !cfg ! ( feature = "flash-attn" )
327+ || get_runtime_compute_cap ( ) . unwrap ( ) < 80
328+ {
329+ return Err ( BackendError :: Start ( "Mistral is only supported on Cuda devices in fp16 with flash attention v2 enabled" . to_string ( ) ) ) ;
330+ }
331+ tracing:: info!( "Starting FlashMistral model on {:?}" , device) ;
332+ Ok ( Box :: new (
333+ FlashMistralModel :: load ( vb, & config, model_type) . s ( ) ?,
334+ ) )
335+ }
277336 } ;
278337
279338 Ok ( Self {
0 commit comments