diff --git a/quantization/image_classification/migraphx/resnet50/e2e_migraphx_resnet_example.py b/quantization/image_classification/migraphx/resnet50/e2e_migraphx_resnet_example.py
index 52dd29802..f92969126 100644
--- a/quantization/image_classification/migraphx/resnet50/e2e_migraphx_resnet_example.py
+++ b/quantization/image_classification/migraphx/resnet50/e2e_migraphx_resnet_example.py
@@ -10,9 +10,121 @@ import onnxruntime
 from onnxruntime.quantization import CalibrationDataReader, create_calibrator, write_calibration_table
 
+
+def custom_write_calibration_table(calibration_cache, filename):
+    """
+    Helper function to write calibration table to files.
+    """
+
+    import json
+    import logging
+    import flatbuffers
+    import numpy as np
+
+    import onnxruntime.quantization.CalTableFlatBuffers.KeyValue as KeyValue
+    import onnxruntime.quantization.CalTableFlatBuffers.TrtTable as TrtTable
+    from onnxruntime.quantization.calibrate import CalibrationMethod, TensorData, TensorsData
+
+    logging.info(f"calibration cache: {calibration_cache}")
+
+    class MyEncoder(json.JSONEncoder):
+        def default(self, obj):
+            if isinstance(obj, (TensorData, TensorsData)):
+                return obj.to_dict()
+            if isinstance(obj, TensorDataWrapper):
+                return obj.data_dict
+            if isinstance(obj, np.ndarray):
+                return {"data": obj.tolist(), "dtype": str(obj.dtype), "CLS": "numpy.array"}
+            if isinstance(obj, CalibrationMethod):
+                return {"CLS": obj.__class__.__name__, "value": str(obj)}
+            return json.JSONEncoder.default(self, obj)
+
+    json_data = json.dumps(calibration_cache, cls=MyEncoder)
+
+    with open(filename, "w") as file:
+        file.write(json_data)  # use `json.loads` to do the reverse
+
+    # Serialize data using FlatBuffers
+    zero = np.array(0)
+    builder = flatbuffers.Builder(1024)
+    key_value_list = []
+
+    for key in sorted(calibration_cache.keys()):
+        values = calibration_cache[key]
+        d_values = values.to_dict()
+
+        highest = d_values.get("highest", zero)
+        lowest = d_values.get("lowest", zero)
+
+        highest_val = highest.item() if hasattr(highest, "item") else float(highest)
+        lowest_val = lowest.item() if hasattr(lowest, "item") else float(lowest)
+
+        floats = [float(highest_val), float(lowest_val)]
+
+        value = str(max(floats))
+
+        flat_key = builder.CreateString(key)
+        flat_value = builder.CreateString(value)
+
+        KeyValue.KeyValueStart(builder)
+        KeyValue.KeyValueAddKey(builder, flat_key)
+        KeyValue.KeyValueAddValue(builder, flat_value)
+        key_value = KeyValue.KeyValueEnd(builder)
+
+        key_value_list.append(key_value)
+
+
+    TrtTable.TrtTableStartDictVector(builder, len(key_value_list))
+    for key_value in key_value_list:
+        builder.PrependUOffsetTRelative(key_value)
+    main_dict = builder.EndVector()
+
+    TrtTable.TrtTableStart(builder)
+    TrtTable.TrtTableAddDict(builder, main_dict)
+    cal_table = TrtTable.TrtTableEnd(builder)
+
+    builder.Finish(cal_table)
+    buf = builder.Output()
+
+    with open(filename, "wb") as file:
+        file.write(buf)
+
+    # Deserialize data (for validation)
+    if os.environ.get("QUANTIZATION_DEBUG", 0) in (1, "1"):
+        cal_table = TrtTable.TrtTable.GetRootAsTrtTable(buf, 0)
+        dict_len = cal_table.DictLength()
+        for i in range(dict_len):
+            key_value = cal_table.Dict(i)
+            logging.info(key_value.Key())
+            logging.info(key_value.Value())
+
+    # write plain text
+    with open(filename + ".cache", "w") as file:
+        for key in sorted(calibration_cache.keys()):
+            values = calibration_cache[key]
+            d_values = values.to_dict()
+            highest = d_values.get("highest", zero)
+            lowest = d_values.get("lowest", zero)
+
+            highest_val = highest.item() if hasattr(highest, "item") else float(highest)
+            lowest_val = lowest.item() if hasattr(lowest, "item") else float(lowest)
+
+            floats = [float(highest_val), float(lowest_val)]
+
+            value = key + " " + str(max(floats))
+            file.write(value)
+            file.write("\n")
+
+
 def parse_input_args():
     parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--model",
+        required=False,
+        default='./resnet50-v2-7.onnx',
+        help='Path to the ONNX model. Default is ./resnet50-v2-7.onnx',
+    )
+
     parser.add_argument(
         "--fp16",
         action="store_true",
@@ -29,6 +141,14 @@ def parse_input_args():
         help='Perform no quantization',
     )
 
+    parser.add_argument(
+        "--fp8",
+        action="store_true",
+        required=False,
+        default=False,
+        help='Perform fp8 quantization instead of int8',
+    )
+
     parser.add_argument(
         "--image_dir",
         required=False,
@@ -48,6 +168,29 @@
         help='Size of images for calibration',
         type=int)
 
+    parser.add_argument(
+        "--exhaustive_tune",
+        action="store_true",
+        required=False,
+        default=False,
+        help='Enable MIGraphX Exhaustive tune before compile. Default False',
+    )
+
+    parser.add_argument(
+        "--cache",
+        action="store_true",
+        required=False,
+        default=True,
+        help='Cache the compiled model between runs. Saves quantization and compile time. Default True',
+    )
+
+    parser.add_argument(
+        "--cache_name",
+        required=False,
+        default="./cached_model.mxr",
+        help='Name and path of the compiled model cache. Default: ./cached_model.mxr',
+    )
+
     return parser.parse_args()
 
 class ImageNetDataReader(CalibrationDataReader):
@@ -255,6 +398,7 @@ class ImageClassificationEvaluator:
     def __init__(self,
                  model_path,
                  synset_id,
+                 flags,
                  data_reader: CalibrationDataReader,
                  providers=["MIGraphXExecutionProvider"]):
         '''
@@ -276,10 +420,21 @@ def get_result(self):
 
     def predict(self):
         sess_options = onnxruntime.SessionOptions()
-        sess_options.log_severity_level = 0
-        sess_options.log_verbosity_level = 0
+        sess_options.log_severity_level = 2
+        sess_options.log_verbosity_level = 2
         sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
-        session = onnxruntime.InferenceSession(self.model_path, sess_options=sess_options, providers=self.providers)
+        session = onnxruntime.InferenceSession(self.model_path, sess_options=sess_options,
+                                               providers=[("MIGraphXExecutionProvider",
+                                                           {"migraphx_fp8_enable": flags.fp8 and not flags.fp32,
+                                                            "migraphx_int8_enable": not (flags.fp8 or flags.fp32),
+                                                            "migraphx_fp16_enable": flags.fp16 and not flags.fp32,
+                                                            "migraphx_int8_calibration_table_name": flags.calibration_table,
+                                                            "migraphx_use_native_calibration_table": flags.native_calibration_table,
+                                                            "migraphx_save_compiled_model": flags.cache,
+                                                            "migraphx_save_model_path": flags.cache_name,
+                                                            "migraphx_load_compiled_model": flags.cache,
+                                                            "migraphx_load_model_path": flags.cache_name,
+                                                            "migraphx_exhaustive_tune": flags.exhaustive_tune})])
 
         inference_outputs_list = []
         while True:
@@ -362,21 +517,31 @@ def get_dataset_size(dataset_path, calibration_dataset_size):
     flags = parse_input_args()
 
     # Dataset settings
-    model_path = "./resnet50-v2-7.onnx"
+    model_path = flags.model
     ilsvrc2012_dataset_path = flags.image_dir
     augmented_model_path = "./augmented_model.onnx"
     batch_size = flags.batch
     calibration_dataset_size = 0 if flags.fp32 else flags.cal_size  # Size of dataset for calibration
 
+    precision = ""
+
+    if not (flags.fp8 or flags.fp32):
+        precision = precision + "_int8"
+
+    if flags.fp8 and not flags.fp32:
+        precision = precision + "_fp8"
+
+    if flags.fp16 and not flags.fp32:
+        precision = "_fp16" + precision
+
     calibration_table_generation_enable = False
     if not flags.fp32:
-        # INT8 calibration setting
         calibration_table_generation_enable = True  # Enable/Disable INT8 calibration
-
-        # MIGraphX EP INT8 settings
-        os.environ["ORT_MIGRAPHX_INT8_ENABLE"] = "1"  # Enable INT8 precision
-        os.environ["ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME"] = "calibration.flatbuffers"  # Calibration table name
-        os.environ["ORT_MIGRAPHX_INT8_NATIVE_CALIBRATION_TABLE"] = "0"  # Calibration table name
+        flags.calibration_table = "calibration_cal" + str(flags.cal_size) + precision + ".flatbuffers"
+        flags.native_calibration_table = "False"
+        if os.path.isfile("./" + flags.calibration_table):
+            calibration_table_generation_enable = False
+            print("Found previous calibration: " + flags.calibration_table + ". Skipping table generation")
 
     execution_provider = ["MIGraphXExecutionProvider"]
 
@@ -396,25 +561,46 @@ def get_dataset_size(dataset_path, calibration_dataset_size):
                                           start_index=0,
                                           end_index=calibration_dataset_size,
                                           stride=calibration_dataset_size,
-                                          batch_size=batch_size,
+                                          batch_size=1,
                                           model_path=augmented_model_path,
                                           input_name=input_name)
         calibrator.collect_data(data_reader)
         cal_tensors = calibrator.compute_data()
 
-        serial_cal_tensors = {}
-        for keys, values in cal_tensors.data.items():
-            serial_cal_tensors[keys] = [float(x[0]) for x in values.range_value]
+        class TensorDataWrapper:
+            def __init__(self, data_dict):
+                self.data_dict = data_dict
+
+            def to_dict(self):
+                return self.data_dict
+
+            def __repr__(self):
+                return repr(self.data_dict)
+
+            def __serializable__(self):
+                return self.data_dict
+
+        calibration_data = {}
+        for k, v in cal_tensors.data.items():
+            if hasattr(v, 'to_dict'):
+                tensor_dict = v.to_dict()
+                processed_dict = {}
+                for dk, dv in tensor_dict.items():
+                    if isinstance(dv, np.ndarray):
+                        processed_dict[dk] = dv.item() if dv.size == 1 else dv.tolist()
+                    elif isinstance(dv, np.number):
+                        processed_dict[dk] = dv.item()
+                    else:
+                        processed_dict[dk] = dv
+                calibration_data[k] = TensorDataWrapper(processed_dict)
+            else:
+                calibration_data[k] = v
 
-        print("Writing calibration table")
-        write_calibration_table(serial_cal_tensors)
+        print("Writing calibration table to: " + flags.calibration_table)
+        custom_write_calibration_table(calibration_data, flags.calibration_table)
         print("Write complete")
 
-    if flags.fp16:
-        os.environ["ORT_MIGRAPHX_FP16_ENABLE"] = "1"
-    else:
-        os.environ["ORT_MIGRAPHX_FP16_ENABLE"] = "0"
-
-    # Run prediction in MIGraphX EP
+
     data_reader = ImageNetDataReader(ilsvrc2012_dataset_path,
                                      start_index=calibration_dataset_size,
@@ -427,14 +613,9 @@ def get_dataset_size(dataset_path, calibration_dataset_size):
     synset_id = data_reader.get_synset_id(ilsvrc2012_dataset_path, calibration_dataset_size,
                                           prediction_dataset_size)  # Generate synset id
     print("Prepping Evalulator")
-    evaluator = ImageClassificationEvaluator(new_model_path, synset_id, data_reader, providers=execution_provider)
+    evaluator = ImageClassificationEvaluator(new_model_path, synset_id, flags, data_reader, providers=execution_provider)
     print("Performing Predictions")
     evaluator.predict()
     print("Read out answer")
     result = evaluator.get_result()
     evaluator.evaluate(result)
-
-    #Set OS flags to off to ensure we don't interfere with other test runs
-
-    os.environ["ORT_MIGRAPHX_FP16_ENABLE"] = "0"
-    os.environ["ORT_MIGRAPHX_INT8_ENABLE"] = "0"
diff --git a/quantization/nlp/bert/migraphx/e2e_migraphx_bert_example.py b/quantization/nlp/bert/migraphx/e2e_migraphx_bert_example.py
index cd891ff94..11df2761e 100644
--- a/quantization/nlp/bert/migraphx/e2e_migraphx_bert_example.py
+++ b/quantization/nlp/bert/migraphx/e2e_migraphx_bert_example.py
@@ -277,6 +277,109 @@ def get_op_nodes_not_followed_by_specific_op(model, op1, op2):
     return not_selected_op1_nodes
 
+
+def custom_write_calibration_table(calibration_cache, filename):
+    """
+    Helper function to write calibration table to files.
+    """
+
+    import json
+    import logging
+    import flatbuffers
+    import numpy as np
+
+    import onnxruntime.quantization.CalTableFlatBuffers.KeyValue as KeyValue
+    import onnxruntime.quantization.CalTableFlatBuffers.TrtTable as TrtTable
+    from onnxruntime.quantization.calibrate import CalibrationMethod, TensorData, TensorsData
+
+    logging.info(f"calibration cache: {calibration_cache}")
+
+    class MyEncoder(json.JSONEncoder):
+        def default(self, obj):
+            if isinstance(obj, (TensorData, TensorsData)):
+                return obj.to_dict()
+            if isinstance(obj, TensorDataWrapper):
+                return obj.data_dict
+            if isinstance(obj, np.ndarray):
+                return {"data": obj.tolist(), "dtype": str(obj.dtype), "CLS": "numpy.array"}
+            if isinstance(obj, CalibrationMethod):
+                return {"CLS": obj.__class__.__name__, "value": str(obj)}
+            return json.JSONEncoder.default(self, obj)
+
+    json_data = json.dumps(calibration_cache, cls=MyEncoder)
+
+    with open(filename, "w") as file:
+        file.write(json_data)  # use `json.loads` to do the reverse
+
+    # Serialize data using FlatBuffers
+    zero = np.array(0)
+    builder = flatbuffers.Builder(1024)
+    key_value_list = []
+
+    for key in sorted(calibration_cache.keys()):
+        values = calibration_cache[key]
+        d_values = values.to_dict()
+
+        highest = d_values.get("highest", zero)
+        lowest = d_values.get("lowest", zero)
+
+        highest_val = highest.item() if hasattr(highest, "item") else float(highest)
+        lowest_val = lowest.item() if hasattr(lowest, "item") else float(lowest)
+
+        floats = [float(highest_val), float(lowest_val)]
+
+        value = str(max(floats))
+
+        flat_key = builder.CreateString(key)
+        flat_value = builder.CreateString(value)
+
+        KeyValue.KeyValueStart(builder)
+        KeyValue.KeyValueAddKey(builder, flat_key)
+        KeyValue.KeyValueAddValue(builder, flat_value)
+        key_value = KeyValue.KeyValueEnd(builder)
+
+        key_value_list.append(key_value)
+
+
+    TrtTable.TrtTableStartDictVector(builder, len(key_value_list))
+    for key_value in key_value_list:
+        builder.PrependUOffsetTRelative(key_value)
+    main_dict = builder.EndVector()
+
+    TrtTable.TrtTableStart(builder)
+    TrtTable.TrtTableAddDict(builder, main_dict)
+    cal_table = TrtTable.TrtTableEnd(builder)
+
+    builder.Finish(cal_table)
+    buf = builder.Output()
+
+    with open(filename, "wb") as file:
+        file.write(buf)
+
+    # Deserialize data (for validation)
+    if os.environ.get("QUANTIZATION_DEBUG", 0) in (1, "1"):
+        cal_table = TrtTable.TrtTable.GetRootAsTrtTable(buf, 0)
+        dict_len = cal_table.DictLength()
+        for i in range(dict_len):
+            key_value = cal_table.Dict(i)
+            logging.info(key_value.Key())
+            logging.info(key_value.Value())
+
+    # write plain text
+    with open(filename + ".cache", "w") as file:
+        for key in sorted(calibration_cache.keys()):
+            values = calibration_cache[key]
+            d_values = values.to_dict()
+            highest = d_values.get("highest", zero)
+            lowest = d_values.get("lowest", zero)
+
+            highest_val = highest.item() if hasattr(highest, "item") else float(highest)
+            lowest_val = lowest.item() if hasattr(lowest, "item") else float(lowest)
+
+            floats = [float(highest_val), float(lowest_val)]
+
+            value = key + " " + str(max(floats))
+            file.write(value)
+            file.write("\n")
+
 def parse_input_args():
     parser = argparse.ArgumentParser()
 
@@ -297,6 +400,14 @@ def parse_input_args():
         help='Perform int8 quantization on the model before running inference',
     )
 
+    parser.add_argument(
+        "--fp8",
+        action="store_true",
+        required=False,
+        default=False,
+        help='Perform fp8 quantization on the model before running inference',
+    )
+
     parser.add_argument(
         "--ep",
         action="store",
@@ -315,6 +426,15 @@ def parse_input_args():
         help='The desired execution provider [MIGraphX, ROCm, CPU] for int8 quantization; Default is MIGraphX',
     )
 
+    parser.add_argument(
+        "--calibration_table",
+        action="store",
+        required=False,
+        default="bert_calibration_table_100_int8.flatbuffers",
+        type=str,
+        help='Use a previously created calibration table; default is bert_calibration_table_100_int8.flatbuffers',
+    )
+
     parser.add_argument(
         "--model",
         action="store",
@@ -434,7 +554,7 @@ def output_run_config(flags, samples):
     print ("filename:" + flags.model)
     print ("Samples: " + str(samples) + " Batch size: " + str(flags.batch))
     print ("Sequence length: " + str(flags.seq_len))
-    print ("Model Quantization: fp16:" + str(flags.fp16) + " int8:" + str(flags.int8))
+    print ("Model Quantization: fp16:" + str(flags.fp16) + " int8:" + str(flags.int8) + " fp8:" + str(flags.fp8))
     if flags.int8:
         if flags.ort_quant:
             print ("Quantizer: Onnxruntime")
@@ -524,43 +644,95 @@ def output_run_config(flags, samples):
         samples = flags.batch
 
     model_quants = ""
+    provider_args = {}
+
+    if flags.int8 and flags.fp8:
+        print("INT8 and FP8 quantization are mutually exclusive for calibration")
+        exit()
+
+    precision = ""
     if flags.int8:
-        model = onnx.load_model(model_path)
+        precision = precision + "_int8"
+        provider_args["migraphx_int8_enable"] = str(True)
 
-        # Generate INT8 calibration cache
-        print("Calibration data compute starts with " + str(cal_ep))
-        calibrator = create_calibrator(model_path, op_types_to_quantize, augmented_model_path=augmented_model_path, calibrate_method=CalibrationMethod.Percentile)
-        calibrator.set_execution_providers([cal_ep])
+    if flags.fp8:
+        precision = precision + "_fp8"
+        provider_args["migraphx_fp8_enable"] = str(True)
 
-        '''
-        We can use one data reader to do data pre-processing, however,
-        some machines don't have sufficient memory to hold all dataset and all intermediate output,
-        especially using 'Entropy' or 'Percentile' calibrator which collects histogram for tensors.
-        So let multiple data readers to handle different stride of dataset to avoid OOM.
-        '''
-        stride = 10
-        #for i in range(0, calib_num, stride):
-        data_reader = BertDataReader(model_path, input_dataset, input_tokens, batch_size, sequence_lengths[-1], flags.query_len, doc_stride[-1], start_index=0, end_index=calib_num)
-        calibrator.collect_data(data_reader)
+    if flags.int8 or flags.fp8:
+        model = onnx.load_model(model_path)
+        provider_args["migraphx_int8_calibration_table_name"] = str(flags.calibration_table)
+        if os.path.isfile("./" + flags.calibration_table):
+            print("Found previous calibration: " + flags.calibration_table + ". Skipping table generation")
+            provider_args["migraphx_int8_calibration_table_name"] = str(flags.calibration_table)
+        else:
+            calibration_table_name = "bert_calibration_table_" + str(flags.cal_num) + precision + ".flatbuffers"
+            print("Unable to find " + flags.calibration_table + ". Generating table: " + calibration_table_name)
+            provider_args["migraphx_int8_calibration_table_name"] = calibration_table_name
+
+        # Generate INT8 calibration cache
+        print("Calibration data compute starts with " + str(cal_ep))
+        calibrator = create_calibrator(model_path, op_types_to_quantize, augmented_model_path=augmented_model_path, calibrate_method=CalibrationMethod.Percentile)
+        calibrator.set_execution_providers([cal_ep])
+
+        '''
+        We can use one data reader to do data pre-processing, however,
+        some machines don't have sufficient memory to hold all dataset and all intermediate output,
+        especially using 'Entropy' or 'Percentile' calibrator which collects histogram for tensors.
+        So let multiple data readers to handle different stride of dataset to avoid OOM.
+        '''
+        stride = 10
+        #for i in range(0, calib_num, stride):
+        data_reader = BertDataReader(model_path, input_dataset, input_tokens, 1, sequence_lengths[-1], flags.query_len, doc_stride[-1], start_index=0, end_index=calib_num)
+        calibrator.collect_data(data_reader)
+
+        compute_range = calibrator.compute_data()
+
+        calibration_table = {}
+        print("Writing calibration table")
+
+        class TensorDataWrapper:
+            def __init__(self, data_dict):
+                self.data_dict = data_dict
+
+            def to_dict(self):
+                return self.data_dict
+
+            def __repr__(self):
+                return repr(self.data_dict)
+
+            def __serializable__(self):
+                return self.data_dict
+
+        calibration_data = {}
+        for k, v in compute_range.data.items():
+            if hasattr(v, 'to_dict'):
+                tensor_dict = v.to_dict()
+                processed_dict = {}
+                for dk, dv in tensor_dict.items():
+                    if isinstance(dv, np.ndarray):
+                        processed_dict[dk] = dv.item() if dv.size == 1 else dv.tolist()
+                    elif isinstance(dv, np.number):
+                        processed_dict[dk] = dv.item()
+                    else:
+                        processed_dict[dk] = dv
+                calibration_data[k] = TensorDataWrapper(processed_dict)
+            else:
+                calibration_data[k] = v
 
-        compute_range = calibrator.compute_data()
-
-        # ORT returns data as return TensorsData(cal, self.collector.compute_collection_result())
-        # Need to fix this for serialization but also convert values to float from float32 in order for JSON to correctly
-        # write out calibration table
-        json_compute_range = {}
-        for k, v in compute_range.data.items():
-            json_compute_range[k] = (float(v.range_value[0]), float(v.range_value[1]))
+        print("Using custom calibration table function")
+        custom_write_calibration_table(calibration_data, calibration_table_name)
 
+        print("Calibration is done. Calibration cache is saved to " + calibration_table_name)
+        provider_args["migraphx_int8_calibration_table_name"] = calibration_table_name
 
-        write_calibration_table(json_compute_range)
-        print("Calibration is done. Calibration cache is saved to calibration.json")
-        model_quants = model_quants + "_int8"
+        model_quants = model_quants + precision
 
         if flags.ort_quant:
-            print("Int8 Quantization Done with Onnxruntime Quantizer")
+            print(precision + " Quantization Done with Onnxruntime Quantizer")
             mode = QuantizationMode.QLinearOps
             # In TRT, it recommended to add QDQ pair to inputs of Add node followed by ReduceMean node.
             # Mirroring here what TRT does in MIGraphX Quantization to be able to perform an apples to apples comparison
@@ -583,31 +755,23 @@ def output_run_config(flags, samples):
             print("QDQ model is saved to ", qdq_model_path)
         else:
             qdq_model_path = model_path
-            print("Int8 Quantization Done with " + cal_ep)
-            #Quantize with MIGraphX's INT8 quantizer instead
-            os.environ["ORT_MIGRAPHX_INT8_ENABLE"] = "1"  # Enable MIGRAPHX INT8 precision
-            os.environ["ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME"] = "calibration.flatbuffers"  # Calibration table name
-            os.environ["ORT_MIGRAPHX_INT8_NATIVE_CALIBRATION_TABLE"] = "0"  # Calibration table name
+            print(precision + " Quantization Done with " + cal_ep)
+            #Quantize with MIGraphX's INT8/FP8 quantizer instead
     else:
         qdq_model_path = model_path
-        os.environ["ORT_MIGRAPHX_INT8_ENABLE"] = "0"  # Disable MIGRAPHX INT8 precision
 
     # No fp16 cal needed, MIGraphX will handle that through Onnxruntime & MIGraphX Execution Provider during compile
     if flags.fp16:
-        os.environ["ORT_MIGRAPHX_FP16_ENABLE"] = "1"  # Enable MIGRAPHX FP16 precision
         model_quants = model_quants + "_fp16"
-    else:
-        os.environ["ORT_MIGRAPHX_FP16_ENABLE"] = "0"  # Disable MIGRAPHX FP16 precision
+        provider_args["migraphx_fp16_enable"] = str(True)
+
     model_name = ""
     if flags.save_load:
         model_name = str(qdq_model_path) + "_s" + str(flags.seq_len) + "_b" + str(flags.batch) + str(model_quants) + ".mxr"
         print("save load model from " + str(model_name))
-        os.environ["ORT_MIGRAPHX_SAVE_COMPILED_MODEL"] = "1"
-        os.environ["ORT_MIGRAPHX_LOAD_COMPILED_MODEL"] = "1"
-        os.environ["ORT_MIGRAPHX_SAVE_COMPILE_PATH"] = model_name
-        os.environ["ORT_MIGRAPHX_LOAD_COMPILE_PATH"] = model_name
+        provider_args["migraphx_model_cache_dir"] = model_name
 
-    # QDQ model inference and get SQUAD prediction 
+    # QDQ model inference and get SQUAD prediction
     batch_size = flags.batch
     data_reader = BertDataReader(qdq_model_path, input_dataset, input_tokens, batch_size, sequence_lengths[-1], flags.query_len, doc_stride[-1], end_index=samples)
     sess_options = onnxruntime.SessionOptions()
@@ -616,7 +780,8 @@ def output_run_config(flags, samples):
     sess_options.log_verbosity_level = 0
     sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
 
-    ort_session = onnxruntime.InferenceSession(qdq_model_path, sess_options=sess_options, providers=[ep])
+    ort_session = onnxruntime.InferenceSession(qdq_model_path, sess_options=sess_options,
+                                               providers=[("MIGraphXExecutionProvider", provider_args)])
 
     print("Running Inferences")
     latency = []  #Used for timing information