@@ -339,11 +339,69 @@ impl JinaBertEncoder {
339339 }
340340}
341341
342+ pub trait ClassificationHead {
343+ fn forward ( & self , hidden_states : & Tensor ) -> Result < Tensor > ;
344+ }
345+
346+ pub struct JinaBertClassificationHead {
347+ pooler : Option < Linear > ,
348+ output : Linear ,
349+ span : tracing:: Span ,
350+ }
351+
352+ impl JinaBertClassificationHead {
353+ pub ( crate ) fn load ( vb : VarBuilder , config : & BertConfig ) -> Result < Self > {
354+ let n_classes = match & config. id2label {
355+ None => candle:: bail!( "`id2label` must be set for classifier models" ) ,
356+ Some ( id2label) => id2label. len ( ) ,
357+ } ;
358+
359+ let pooler = if let Ok ( pooler_weight) = vb
360+ . pp ( "bert.pooler.dense" )
361+ . get ( ( config. hidden_size , config. hidden_size ) , "weight" )
362+ {
363+ let pooler_bias = vb. pp ( "bert.pooler.dense" ) . get ( config. hidden_size , "bias" ) ?;
364+ Some ( Linear :: new ( pooler_weight, Some ( pooler_bias) , None ) )
365+ } else {
366+ None
367+ } ;
368+
369+ let output_weight = vb
370+ . pp ( "classifier" )
371+ . get ( ( n_classes, config. hidden_size ) , "weight" ) ?;
372+ let output_bias = vb. pp ( "classifier" ) . get ( n_classes, "bias" ) ?;
373+ let output = Linear :: new ( output_weight, Some ( output_bias) , None ) ;
374+
375+ Ok ( Self {
376+ pooler,
377+ output,
378+ span : tracing:: span!( tracing:: Level :: TRACE , "classifier" ) ,
379+ } )
380+ }
381+ }
382+
383+ impl ClassificationHead for JinaBertClassificationHead {
384+ fn forward ( & self , hidden_states : & Tensor ) -> Result < Tensor > {
385+ let _enter = self . span . enter ( ) ;
386+
387+ let mut hidden_states = hidden_states. unsqueeze ( 1 ) ?;
388+ if let Some ( pooler) = self . pooler . as_ref ( ) {
389+ hidden_states = pooler. forward ( & hidden_states) ?;
390+ hidden_states = hidden_states. tanh ( ) ?;
391+ }
392+
393+ let hidden_states = self . output . forward ( & hidden_states) ?;
394+ let hidden_states = hidden_states. squeeze ( 1 ) ?;
395+ Ok ( hidden_states)
396+ }
397+ }
398+
342399pub struct JinaBertModel {
343400 embeddings : JinaEmbeddings ,
344401 encoder : JinaBertEncoder ,
345402 pool : Pool ,
346403 alibi : Option < Tensor > ,
404+ classifier : Option < Box < dyn ClassificationHead + Send > > ,
347405
348406 num_attention_heads : usize ,
349407
@@ -366,9 +424,12 @@ impl JinaBertModel {
366424 _ => candle:: bail!( "not supported" ) ,
367425 } ;
368426
369- let pool = match model_type {
427+ let ( pool, classifier ) = match model_type {
370428 ModelType :: Classifier => {
371- candle:: bail!( "`classifier` model type is not supported for Jina" )
429+ let pool = Pool :: Cls ;
430+ let classifier: Box < dyn ClassificationHead + Send > =
431+ Box :: new ( JinaBertClassificationHead :: load ( vb. clone ( ) , config) ?) ;
432+ ( pool, Some ( classifier) )
372433 }
373434 ModelType :: Embedding ( pool) => {
374435 if pool == Pool :: Splade {
@@ -377,7 +438,7 @@ impl JinaBertModel {
377438 if pool == Pool :: LastToken {
378439 candle:: bail!( "`last_token` is not supported for Jina" ) ;
379440 }
380- pool
441+ ( pool, None )
381442 }
382443 } ;
383444
@@ -403,6 +464,7 @@ impl JinaBertModel {
403464 encoder,
404465 pool,
405466 alibi,
467+ classifier,
406468 num_attention_heads : config. num_attention_heads ,
407469 device : vb. device ( ) . clone ( ) ,
408470 dtype : vb. dtype ( ) ,
@@ -667,7 +729,20 @@ impl Model for JinaBertModel {
667729 fn is_padded ( & self ) -> bool {
668730 true
669731 }
732+
670733 fn embed ( & self , batch : Batch ) -> Result < ( Option < Tensor > , Option < Tensor > ) > {
671734 self . forward ( batch)
672735 }
736+
737+ fn predict ( & self , batch : Batch ) -> Result < Tensor > {
738+ match & self . classifier {
739+ None => candle:: bail!( "`predict` is not implemented for this model" ) ,
740+ Some ( classifier) => {
741+ let ( pooled_embeddings, _raw_embeddings) = self . forward ( batch) ?;
742+ let pooled_embeddings =
743+ pooled_embeddings. expect ( "pooled_embeddings is empty. This is a bug." ) ;
744+ classifier. forward ( & pooled_embeddings)
745+ }
746+ }
747+ }
673748}
0 commit comments