@@ -4,6 +4,7 @@
 
 from transformers import AutoConfig, AutoModelForCausalLM, AutoModel, AutoModelForVision2Seq, AutoTokenizer
 
+from .reranking_evaluator import DEF_TOP_K, DEF_MAX_LENGTH
 from .utils import mock_torch_cuda_is_available, mock_AwqQuantizer_validate_environment
 
 
@@ -20,7 +21,7 @@ def __init__(self, model, model_dir, model_type):
         self.model = model
         self.model_type = model_type
 
-        if model_type == "text" or model_type == "visual-text":
+        if model_type in ["text", "visual-text", "text-reranking"]:
             try:
                 self.config = AutoConfig.from_pretrained(model_dir)
             except Exception:
@@ -428,6 +429,68 @@ def load_inpainting_model(
     return model
 
 
+def load_reranking_genai_pipeline(model_dir, device="CPU", ov_config=None):
+    try:
+        import openvino_genai
+    except ImportError as e:
+        logger.error("Failed to import openvino_genai package. Please install it. Details:\n", e)
+        exit(-1)
+
+    logger.info("Using OpenVINO GenAI TextRerankPipeline API")
+
+    config = openvino_genai.TextRerankPipeline.Config()
+    config.top_n = DEF_TOP_K
+    config.max_length = DEF_MAX_LENGTH
+
+    pipeline = openvino_genai.TextRerankPipeline(model_dir, device.upper(), config, **ov_config)
+
+    return GenAIModelWrapper(
+        pipeline,
+        model_dir,
+        "text-reranking"
+    )
+
+
+def load_reranking_model(model_id, device="CPU", ov_config=None, use_hf=False, use_genai=False):
+    if use_hf:
+        logger.info("Using HF Transformers API")
+        if 'qwen3' in model_id.lower():
+            from transformers import AutoModelForCausalLM
+            model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
+        else:
+            from transformers import AutoModelForSequenceClassification
+            model = AutoModelForSequenceClassification.from_pretrained(model_id, trust_remote_code=True)
+    elif use_genai:
+        logger.info("Using OpenVINO GenAI API")
+        model = load_reranking_genai_pipeline(model_id, device, ov_config)
+    else:
+        logger.info("Using Optimum API")
+        model_cls = None
+        if 'qwen3' in model_id.lower():
+            from optimum.intel.openvino import OVModelForCausalLM
+            model_cls = OVModelForCausalLM
+        else:
+            from optimum.intel.openvino import OVModelForSequenceClassification
+            model_cls = OVModelForSequenceClassification
+
+        try:
+            model = model_cls.from_pretrained(
+                model_id, device=device, ov_config=ov_config, safety_checker=None,
+            )
+        except ValueError as e:
+            logger.error("Failed to load reranking pipeline. Details:\n", e)
+            model = model_cls.from_pretrained(
+                model_id,
+                trust_remote_code=True,
+                use_cache=False,
+                device=device,
+                ov_config=ov_config,
+                safety_checker=None
+            )
+
+    return model
+
+
 def load_model(
     model_type, model_id, device="CPU", ov_config=None, use_hf=False, use_genai=False, use_llamacpp=False, **kwargs
 ):
@@ -452,5 +515,7 @@ def load_model(
         return load_imagetext2image_model(model_id, device, ov_options, use_hf, use_genai)
     elif model_type == "image-inpainting":
         return load_inpainting_model(model_id, device, ov_options, use_hf, use_genai)
+    elif model_type == "text-reranking":
+        return load_reranking_model(model_id, device, ov_options, use_hf, use_genai)
     else:
         raise ValueError(f"Unsupported model type: {model_type}")
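
For reference, below is a minimal usage sketch of the new "text-reranking" dispatch path. The module path and the model directory name are assumptions for illustration and are not part of this diff; it also presumes the reranking model has already been exported to OpenVINO format.

# Illustrative sketch only: "whowhatbench.model_loaders" and "./reranker-ov"
# are assumed names, not confirmed by this diff.
from whowhatbench.model_loaders import load_model

# GenAI route: wraps openvino_genai.TextRerankPipeline in GenAIModelWrapper,
# configured with DEF_TOP_K / DEF_MAX_LENGTH from reranking_evaluator.
reranker_genai = load_model(
    model_type="text-reranking",
    model_id="./reranker-ov",
    device="CPU",
    use_genai=True,
)

# Default Optimum route: OVModelForSequenceClassification, or
# OVModelForCausalLM when "qwen3" appears in the model id.
reranker_optimum = load_model(
    model_type="text-reranking",
    model_id="./reranker-ov",
    device="CPU",
)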