@@ -11,12 +11,13 @@
 
 from __future__ import annotations
 
+from typing import Optional
+
 from onnxscript.function_libs.torch_lib.ops import common
 from onnxscript.function_libs.torch_lib.registration import torch_op
 from onnxscript.onnx_opset import opset18 as op
 from onnxscript.onnx_opset import opset23 as op23
 from onnxscript.onnx_types import TensorType
-from typing import Optional
 
 
 @torch_op(
@@ -84,7 +85,7 @@ def quantized_decomposed_quantize_per_channel(
 ) -> TensorType:
     """Affine per channel quantization for the Tensor using the same quantization
     parameters for each channel/axis to map from floating point to quantized values.
-
+
     Uses ONNX QuantizeLinear with per-axis quantization support.
     """
     # Use opset23 for per-axis quantization support
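For readers less familiar with the op: QuantizeLinear with an axis attribute applies an affine map with one scale/zero-point pair per channel. A minimal NumPy sketch of the same computation (the function name, int8 output, and clamping bounds are illustrative assumptions, not code from this PR):

import numpy as np

def quantize_per_channel_ref(x, scales, zero_points, axis, qmin=-128, qmax=127):
    # Reshape the per-channel parameters so they broadcast along `axis`.
    shape = [1] * x.ndim
    shape[axis] = -1
    s = scales.reshape(shape)
    z = zero_points.reshape(shape)
    # Affine quantization: round(x / scale) + zero_point, clamped to [qmin, qmax].
    q = np.round(x / s) + z
    return np.clip(q, qmin, qmax).astype(np.int8)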
@@ -111,7 +112,7 @@ def quantized_decomposed_dequantize_per_channel(
 ) -> TensorType:
     """Affine per channel dequantization for the Tensor using the same quantization
     parameters for each channel/axis to map from quantized values to floating point values.
-
+
     Uses ONNX DequantizeLinear with per-axis quantization support.
     """
    # Use opset23 for per-axis quantization support with optional output_dtype
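Likewise, a rough NumPy reference for the per-axis DequantizeLinear this function wraps (the name and the float32 output dtype are assumptions for illustration):

import numpy as np

def dequantize_per_channel_ref(q, scales, zero_points, axis):
    # Broadcast the per-channel parameters along `axis`.
    shape = [1] * q.ndim
    shape[axis] = -1
    s = scales.reshape(shape)
    z = zero_points.reshape(shape)
    # Affine inverse: (q - zero_point) * scale, computed in floating point.
    return (q.astype(np.float32) - z.astype(np.float32)) * s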
@@ -120,4 +121,6 @@ def quantized_decomposed_dequantize_per_channel(
         return op23.DequantizeLinear(input, scales, zero_points, axis=axis)
     else:
         assert out_dtype > 0, f"out_dtype must be -1 or > 0 not {out_dtype}"
-        return op23.DequantizeLinear(input, scales, zero_points, axis=axis, output_dtype=out_dtype)
+        return op23.DequantizeLinear(
+            input, scales, zero_points, axis=axis, output_dtype=out_dtype
+        )
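On the out_dtype guard: assuming out_dtype carries an ONNX TensorProto dtype enum value (which forwarding it straight to output_dtype suggests), -1 means "use the default output type" and any positive value selects a concrete type, e.g.:

import onnx

# A few TensorProto enum values out_dtype could carry (fixed by the ONNX spec):
print(onnx.TensorProto.FLOAT)     # 1
print(onnx.TensorProto.FLOAT16)   # 10
print(onnx.TensorProto.BFLOAT16)  # 16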