@@ -464,3 +464,50 @@ def load(checkpoint_dir=None, model=None, layer_wise=False, history_cfg=None, **
     assert len(mismatch_log.unexpected_keys) == 0, "Loading state_dict failed: {}".format(mismatch_log)
     util.get_embedding_contiguous(model)
     return model
+
+
+def recover_model_from_json(model, json_file_path, example_inputs):
+    """Recover an ipex quantized model from a JSON configuration file.
+
+    Args:
+        model (object): fp32 model that needs quantization.
+        json_file_path (str): path to the ipex configuration JSON file.
+        example_inputs (tuple or torch.Tensor or dict): example inputs that will be passed to the ipex functions.
+
+    Returns:
+        (object): quantized model.
+    """
+    from ..utils.utility import LazyImport
+
+    ipex = LazyImport("intel_extension_for_pytorch")
+    from torch.ao.quantization.observer import MinMaxObserver
+
+    qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5, act_observer=MinMaxObserver())
+    if isinstance(example_inputs, dict):
+        model = ipex.quantization.prepare(model, qconfig, example_kwarg_inputs=example_inputs, inplace=True)
+    else:
+        model = ipex.quantization.prepare(model, qconfig, example_inputs=example_inputs, inplace=True)
+    model.load_qconf_summary(qconf_summary=json_file_path)
+    model = ipex.quantization.convert(model, inplace=True)
+    with torch.no_grad():
+        try:
+            if isinstance(example_inputs, dict):
+                # pylint: disable=E1120,E1123
+                model = torch.jit.trace(model, example_kwarg_inputs=example_inputs)
+            else:
+                model = torch.jit.trace(model, example_inputs)
+            model = torch.jit.freeze(model.eval())
+        except Exception:  # fall back to non-strict tracing if strict tracing fails
+            if isinstance(example_inputs, dict):
+                # pylint: disable=E1120,E1123
+                model = torch.jit.trace(model, example_kwarg_inputs=example_inputs, strict=False, check_trace=False)
+            else:
+                model = torch.jit.trace(model, example_inputs, strict=False)
+            model = torch.jit.freeze(model.eval())
+        if isinstance(example_inputs, dict):
+            model(**example_inputs)  # two warm-up runs let TorchScript complete its fusion passes
+            model(**example_inputs)
+        else:
+            model(example_inputs)  # two warm-up runs let TorchScript complete its fusion passes
+            model(example_inputs)
+    return model
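
A minimal usage sketch for the helper added above, assuming it is exposed as neural_compressor.utils.pytorch.recover_model_from_json; the toy model and the "best_configure.json" path are placeholders for a real fp32 module and a qconf summary dumped by an earlier IPEX smooth-quant calibration run:

import torch
from neural_compressor.utils.pytorch import recover_model_from_json

# Placeholder fp32 model and example input; any traceable torch.nn.Module works.
fp32_model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2)).eval()
example_inputs = torch.randn(1, 4)

# "best_configure.json" stands in for the qconf summary saved during calibration.
int8_model = recover_model_from_json(fp32_model, "best_configure.json", example_inputs)

with torch.no_grad():
    print(int8_model(example_inputs))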