Skip to content

Commit 852b648

Browse files
authored
[mlir] Improvements to the 'quant' dialect (#100667)
Full revamp of the 'quant' dialect. This is an implementation for the RFC at https://discourse.llvm.org/t/rfc-improvements-in-the-quant-dialect/79942
1 parent f0162fc commit 852b648

38 files changed

+2886
-271
lines changed
Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,2 @@
1-
add_mlir_dialect(QuantOps quant)
2-
add_mlir_doc(QuantOps QuantDialect Dialects/ -gen-dialect-doc)
3-
4-
set(LLVM_TARGET_DEFINITIONS QuantDialectBytecode.td)
5-
mlir_tablegen(QuantDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Quant")
6-
add_public_tablegen_target(MLIRQuantDialectBytecodeIncGen)
1+
add_subdirectory(IR)
2+
add_subdirectory(Transforms)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
add_mlir_dialect(QuantOps quant)
2+
add_mlir_doc(QuantOps QuantDialect Dialects/ -gen-dialect-doc)
3+
4+
set(LLVM_TARGET_DEFINITIONS QuantDialectBytecode.td)
5+
mlir_tablegen(QuantDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Quant")
6+
add_public_tablegen_target(MLIRQuantDialectBytecodeIncGen)
Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
//===- QuantOps.h - Quantization Ops and Types ------------------*- C++ -*-===//
1+
//===- Quant.h - Quantization Ops -------------------------------*- C++ -*-===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
88

9-
#ifndef MLIR_DIALECT_QUANT_QUANTOPS_H_
10-
#define MLIR_DIALECT_QUANT_QUANTOPS_H_
9+
#ifndef MLIR_DIALECT_QUANT_IR_QUANT_H_
10+
#define MLIR_DIALECT_QUANT_IR_QUANT_H_
1111

1212
#include "mlir/IR/Attributes.h"
1313
#include "mlir/IR/Builders.h"
@@ -19,9 +19,19 @@
1919
#include "mlir/Interfaces/SideEffectInterfaces.h"
2020
#include "llvm/Support/MathExtras.h"
2121

22-
#include "mlir/Dialect/Quant/QuantOpsDialect.h.inc"
22+
#include "mlir/Dialect/Quant/IR/QuantOpsDialect.h.inc"
23+
24+
namespace mlir {
25+
namespace quant {
26+
27+
class QuantizedType;
28+
class UniformQuantizedType;
29+
class UniformQuantizedPerAxisType;
30+
31+
} // namespace quant
32+
} // namespace mlir
2333

2434
#define GET_OP_CLASSES
25-
#include "mlir/Dialect/Quant/QuantOps.h.inc"
35+
#include "mlir/Dialect/Quant/IR/QuantOps.h.inc"
2636

27-
#endif // MLIR_DIALECT_QUANT_QUANTOPS_H_
37+
#endif // MLIR_DIALECT_QUANT_IR_QUANT_H_
Lines changed: 297 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,297 @@
1+
//===- QuantBase.td - Quantization dialect base ------------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Quantization dialect, types, and traits.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef QUANT_BASE
14+
#define QUANT_BASE
15+
16+
include "mlir/IR/OpBase.td"
17+
18+
def Quant_Dialect : Dialect {
19+
let name = "quant";
20+
let description = [{
21+
The `quant` dialect offers a framework for defining and manipulating
22+
quantized values. Central to this framework is the `!quant.uniform` data
23+
type, used to represent quantized values. This dialect also provides a
24+
suite of operations to handle and convert quantized values between their
25+
original floating-point representations and the optimized, lower bit-width
26+
integer representations. The `quant` dialect is instrumented with
27+
transformation passes to lower these operations into other core MLIR
28+
dialects, while also flattening all occurrences of quantized types into
29+
their integer counterparts.
30+
31+
32+
## The `!quant.uniform` type
33+
34+
The quantization process establishes a relationship between two types of
35+
values: an *expressed value* and a *stored value*. The former refers to the
36+
floating-point representation used in an original machine learning model,
37+
capturing the precise numerical characteristics needed for accurate
38+
calculations. The latter is the simplified integer representation that
39+
resides in memory after quantization. The `!quant.uniform` data type
40+
encodes the necessary information for (lossy) round-trip conversion between
41+
an expressed and a stored value.
42+
43+
The `!quant.uniform` type has two variants: per-layer quantization and
44+
per-channel (or per-axis) quantization. In per-layer quantization, the
45+
quantization information affects an entire tensor uniformly. Conversely, in
46+
per-channel quantization, the data type encodes the specific tensor axis
47+
that serves as the channel and includes quantization information for each
48+
individual channel within the tensor. Below are the specific syntactic and
49+
semantic considerations for each modality.
50+
51+
52+
### Per-layer quantization
53+
54+
This is the general syntax of the `!quant.uniform` type representing
55+
per-layer quantization:
56+
57+
```
58+
`!quant.uniform` `<`
59+
storedType (`<` storageMin `:` storageMax `>`)? `:`
60+
expressedType `,`
61+
scale (`:` zeroPoint)?
62+
`>`
63+
```
64+
65+
The type contains the following parameters:
66+
67+
- `storedType`: Integer type of the value stored in memory. This type
68+
conveys the bit width and signedness of the quantized stored value.
69+
Signed integer types are represented as `'i' bitWidth` (e.g., `i8`),
70+
while unsigned integer types are represented as `'u' bitWidth` (e.g.,
71+
`u8`).
72+
73+
- `storageMin`, `storageMax`: Optional bounds for the stored value. If
74+
given, they must be within the range of `storedType`. If omitted, the
75+
entire range of `storedType` is allowed (e.g., `-128...127` for `i8` or
76+
`0...255` for `u8`).
77+
78+
- `expressedType`: Floating-point type of the value expressed by this
79+
quantized type (e.g., `f32`, `f80`, `bf16`, or `tf32`).
80+
81+
- `scale`: Floating-point value of type `expressedType` used in the
82+
conversion between stored and expressed values.
83+
84+
- `zeroPoint`: Optional integer value of type `storedType` used in the
85+
conversion between stored and expressed values. If omitted, the default
86+
is 0.
87+
88+
Type conversions, rounding methods, and clamping actions aside, the
89+
relationship between the expressed and stored values as encoded in a
90+
quantized type is denoted by the following formula:
91+
92+
$$
93+
expressedValue = (storedValue ~-~ zeroPoint) ~\times~ scale
94+
$$
95+
96+
Operations `quant.qcast` (quantize cast) and `quant.dcast` (dequantize
97+
cast) can be used to quantize a floating-point value and dequantize a
98+
stored value, respectively. See the documentation for these operations for
99+
details on how the quantization and dequantization processes are influenced
100+
by the `!quant.uniform` type parameters.
101+
102+
Here are some examples of the use of `!quant.uniform` with per-layer
103+
quantization:
104+
105+
```
106+
// An 8-bit signed integer type is used to represent a 32-bit float. No
107+
// clamping information is provided, so the full [-128, 127] range is
108+
// available. The scale is set to 3.0, and the zero point takes its default
109+
// 0 value.
110+
!quant.uniform<i8:f32, 3.0>
111+
112+
// A 16-bit unsigned integer type is used to represent a 32-bit float. Out
113+
// of the 16 bits, only 10 are used, according to the 0..1023 clamping
114+
// range. The type sets the scale to 1.23 and the zero point to 512.
115+
!quant.uniform<u16<0:1023>:f32, 1.23:512>
116+
```
117+
118+
### Per-channel quantization
119+
120+
The general syntax of the `!quant.uniform` type representing per-channel
121+
quantization is as follows:
122+
123+
```
124+
`!quant.uniform` `<`
125+
storedType (`<` storageMin `:` storageMax `>`)? `:`
126+
expressedType `:`
127+
channelAxis `,`
128+
`{`
129+
scale0 (`:` zeroPoint0)? `,`
130+
scale1 (`:` zeroPoint1)? ...
131+
'}'
132+
`>`
133+
```
134+
135+
In this data type, there are multiple pairs of `scale` and `zeroPoint`
136+
values. The `channelAxis` field represents the dimension of the containing
137+
tensor acting as the channel. The size of the tensor along this dimension
138+
is expected to match the number of provided `scale`-`zeroPoint` pairs, and
139+
a given pair *i* applies to all elements in the tensor whose index along
140+
dimension `channelAxis` is *i*. A quantized data type using per-channel
141+
quantization is always expected to be contained within a tensor type.
142+
143+
Here are some examples:
144+
145+
```
146+
// A 2x3x4 tensor contains 8-bit signed integers representing 32-bit
147+
// floats. Dimension 1 of the tensor acts as the channel dimension. Its
148+
// size 3 matches the number of provided scale values. Tensor elements at
149+
// positions [*][0][*], [*][1][*], and [*][2][*] use scales 3.0, 4.0, and
150+
// 5.0, respectively.
151+
tensor<2x3x4x!quant.uniform<i8:f32:1, {3.0, 4.0, 5.0}>>
152+
153+
// A 2D dynamically sized tensor contains 16-bit unsigned integers
154+
// representing 32-bit floats. Dimension 0 of the tensor acts as the
155+
// channel dimension. Since 2 scale and zero-point values are provided, the
156+
// size of dimension 0 is expected to be 2 at runtime. Tensor elements
157+
// [0][*] use scale 2.0 and zero point 10, while elements [1][*] use scale
158+
// 3.0 and zero point 20.
159+
tensor<?x?x!quant.uniform<u16:f32:0, {2.0:10, 3.0:20}>>
160+
```
161+
162+
163+
## Per-axis quantization integrity
164+
165+
When type `!quant.uniform` contains per-axis quantization information, the
166+
rules below are enforced. These rules guarantee that the quantization
167+
information encoded in the data type is applicable to the context in which
168+
the quantized type is used. For efficiency, these rules are actively
169+
enforced by the verifiers of `quant` dialect ops, but they must be
170+
respected in any context in which the `!quant.uniform` data type is used,
171+
such as the header of a `func.func` op, or the input of an arithmetic
172+
operation.
173+
174+
- A quantized type with per-channel quantization information must be the
175+
element type of a tensor container type, and may not occur directly as
176+
the data type of a scalar value.
177+
178+
```
179+
// Incorrect. Type !quant.uniform specifies per-channel quantization for a
180+
// scalar type.
181+
%result = quant.qcast %input : f32 to !quant.uniform<i8:f32:0, {1.0, 2.0}>
182+
183+
// Correct. Type `!quant.uniform` with per-channel quantization is wrapped
184+
// in a `tensor` type.
185+
%result = quant.qcast %input : tensor<2xf32> to tensor<2x!quant.uniform<i8:f32:0, {1.0, 2.0}>>
186+
```
187+
188+
- If the tensor containing the `!quant.uniform` type is ranked, its rank
189+
must be greater than the channel axis specified in the quantized type.
190+
191+
```
192+
// Incorrect. The tensor rank (2) is not greater than the channel axis in
193+
// the quantized type (3).
194+
%result = quant.qcast %input : tensor<1x2xf32> to tensor<1x2x!quant.uniform<i8:f32:3, {1.0, 2.0}>>
195+
196+
// Correct. The tensor rank (2) is now greater than the channel axis (1):
197+
%result = quant.qcast %input : tensor<1x2xf32> to tensor<1x2x!quant.uniform<i8:f32:1, {1.0, 2.0}>>
198+
```
199+
200+
- If the axis dimension in the containing tensor is static, its size must
201+
be equal to the number of scales present in the quantized type.
202+
203+
```
204+
// Incorrect. The channel axis is 1, and the size of dimension 1 in the
205+
// containing tensor is 3. However, there are 4 scale values present in the
206+
// quantized type.
207+
%result = quant.qcast %input : tensor<?x3xf32> to tensor<?x3x!quant.uniform<i8:f32:1, {1.0, 2.0, 3.0, 4.0}>>
208+
209+
// Correct. The quantized type now includes 3 scale values, matching the
210+
// size of dimension 1 of the result tensor.
211+
%result = quant.qcast %input : tensor<?x3xf32> to tensor<?x3x!quant.uniform<i8:f32:1, {2.0, 3.0, 4.0}>>
212+
```
213+
}];
214+
let cppNamespace = "::mlir::quant";
215+
let useDefaultTypePrinterParser = 1;
216+
}
217+
218+
219+
//===----------------------------------------------------------------------===//
220+
// Type predicates
221+
//===----------------------------------------------------------------------===//
222+
223+
class quant_ScalarOrTensorOf<Type etype> :
224+
Type<Or<[etype.predicate, TensorOf<[etype]>.predicate]>,
225+
"scalar or tensor of " # etype.summary>;
226+
227+
def quant_QuantizedType :
228+
Type<CPred<"::llvm::isa<mlir::quant::QuantizedType>($_self)">, "quantized type">;
229+
230+
def quant_ScalarType :
231+
Type<Or<[
232+
AnySignlessInteger.predicate,
233+
AnyFloat.predicate,
234+
quant_QuantizedType.predicate
235+
]>,
236+
"signless integer, float, or quantized scalar">;
237+
238+
def quant_IntegerOrQuantizedType :
239+
Type<Or<[
240+
AnySignlessInteger.predicate,
241+
quant_QuantizedType.predicate
242+
]>,
243+
"signless integer or quantized type">;
244+
245+
def quant_FloatScalarOrTensor :
246+
quant_ScalarOrTensorOf<AnyFloat>;
247+
248+
def quant_IntegerScalarOrTensor :
249+
quant_ScalarOrTensorOf<AnySignlessInteger>;
250+
251+
def quant_QuantizedScalarOrTensor :
252+
quant_ScalarOrTensorOf<quant_QuantizedType>;
253+
254+
def quant_IntegerOrQuantizedScalarOrTensor :
255+
quant_ScalarOrTensorOf<quant_IntegerOrQuantizedType>;
256+
257+
258+
//===----------------------------------------------------------------------===//
259+
// Traits
260+
//===----------------------------------------------------------------------===//
261+
262+
def quant_SameScalarOrTensorShape :
263+
PredOpTrait<
264+
"input and result are both scalars or both tensors with matching shape",
265+
Or<[
266+
And<[
267+
TypeIsPred<"input", quant_ScalarType>,
268+
TypeIsPred<"result", quant_ScalarType>
269+
]>,
270+
And<[
271+
TypeIsPred<"input", AnyUnrankedTensor>,
272+
TypeIsPred<"result", AnyUnrankedTensor>
273+
]>,
274+
And<[
275+
TypeIsPred<"input", AnyRankedTensor>,
276+
TypeIsPred<"result", AnyRankedTensor>,
277+
AllShapesMatch<["input", "result"]>.predicate
278+
]>
279+
]>
280+
>;
281+
282+
def quant_IntegerAndQuantizedCombination :
283+
PredOpTrait<
284+
"input must be integer and result must be quantized, or vice versa",
285+
Or<[
286+
And<[
287+
TypeIsPred<"input", quant_QuantizedScalarOrTensor>,
288+
TypeIsPred<"result", quant_IntegerScalarOrTensor>
289+
]>,
290+
And<[
291+
TypeIsPred<"input", quant_IntegerScalarOrTensor>,
292+
TypeIsPred<"result", quant_QuantizedScalarOrTensor>
293+
]>
294+
]>
295+
>;
296+
297+
#endif // QUANT_BASE

0 commit comments

Comments
 (0)