diff --git a/mlir/include/mlir/Dialect/Quant/CMakeLists.txt b/mlir/include/mlir/Dialect/Quant/CMakeLists.txt index c08f399ee182d..9f57627c321fb 100644 --- a/mlir/include/mlir/Dialect/Quant/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Quant/CMakeLists.txt @@ -1,6 +1,2 @@ -add_mlir_dialect(QuantOps quant) -add_mlir_doc(QuantOps QuantDialect Dialects/ -gen-dialect-doc) - -set(LLVM_TARGET_DEFINITIONS QuantDialectBytecode.td) -mlir_tablegen(QuantDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Quant") -add_public_tablegen_target(MLIRQuantDialectBytecodeIncGen) +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/include/mlir/Dialect/Quant/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Quant/IR/CMakeLists.txt new file mode 100644 index 0000000000000..c08f399ee182d --- /dev/null +++ b/mlir/include/mlir/Dialect/Quant/IR/CMakeLists.txt @@ -0,0 +1,6 @@ +add_mlir_dialect(QuantOps quant) +add_mlir_doc(QuantOps QuantDialect Dialects/ -gen-dialect-doc) + +set(LLVM_TARGET_DEFINITIONS QuantDialectBytecode.td) +mlir_tablegen(QuantDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Quant") +add_public_tablegen_target(MLIRQuantDialectBytecodeIncGen) diff --git a/mlir/include/mlir/Dialect/Quant/QuantOps.h b/mlir/include/mlir/Dialect/Quant/IR/Quant.h similarity index 59% rename from mlir/include/mlir/Dialect/Quant/QuantOps.h rename to mlir/include/mlir/Dialect/Quant/IR/Quant.h index 14fb3035ab0d3..11a969a3ee519 100644 --- a/mlir/include/mlir/Dialect/Quant/QuantOps.h +++ b/mlir/include/mlir/Dialect/Quant/IR/Quant.h @@ -1,4 +1,4 @@ -//===- QuantOps.h - Quantization Ops and Types ------------------*- C++ -*-===// +//===- Quant.h - Quantization Ops -------------------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_QUANT_QUANTOPS_H_ -#define MLIR_DIALECT_QUANT_QUANTOPS_H_ +#ifndef MLIR_DIALECT_QUANT_IR_QUANT_H_ +#define MLIR_DIALECT_QUANT_IR_QUANT_H_ #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" @@ -19,9 +19,19 @@ #include "mlir/Interfaces/SideEffectInterfaces.h" #include "llvm/Support/MathExtras.h" -#include "mlir/Dialect/Quant/QuantOpsDialect.h.inc" +#include "mlir/Dialect/Quant/IR/QuantOpsDialect.h.inc" + +namespace mlir { +namespace quant { + +class QuantizedType; +class UniformQuantizedType; +class UniformQuantizedPerAxisType; + +} // namespace quant +} // namespace mlir #define GET_OP_CLASSES -#include "mlir/Dialect/Quant/QuantOps.h.inc" +#include "mlir/Dialect/Quant/IR/QuantOps.h.inc" -#endif // MLIR_DIALECT_QUANT_QUANTOPS_H_ +#endif // MLIR_DIALECT_QUANT_IR_QUANT_H_ diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td new file mode 100644 index 0000000000000..791cb9de48d05 --- /dev/null +++ b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td @@ -0,0 +1,297 @@ +//===- QuantBase.td - Quantization dialect base ------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Quantization dialect, types, and traits. +// +//===----------------------------------------------------------------------===// + +#ifndef QUANT_BASE +#define QUANT_BASE + +include "mlir/IR/OpBase.td" + +def Quant_Dialect : Dialect { + let name = "quant"; + let description = [{ + The `quant` dialect offers a framework for defining and manipulating + quantized values. Central to this framework is the `!quant.uniform` data + type, used to represent quantized values. This dialect also provides a + suite of operations to handle and convert quantized values between their + original floating-point representations and the optimized, lower bit-width + integer representations. The `quant` dialect is instrumented with + transformation passes to lower these operations into other core MLIR + dialects, while also flattening all occurrences of quantized types into + their integer counterparts. + + + ## The `!quant.uniform` type + + The quantization process establishes a relationship between two types of + values: an *expressed value* and a *stored value*. The former refers to the + floating-point representation used in an original machine learning model, + capturing the precise numerical characteristics needed for accurate + calculations. The latter is the simplified integer representation that + resides in memory after quantization. The `!quant.uniform` data type + encodes the necessary information for (lossy) round-trip conversion between + an expressed and a stored value. + + The `quant.uniform` type has two variants: per-layer quantization and + per-channel (or per-axis) quantization. In per-layer quantization, the + quantization information affects an entire tensor uniformly. Conversely, in + per-channel quantization, the data type encodes the specific tensor axis + that serves as the channel and includes quantization information for each + individual channel within the tensor. Below are the specific syntactic and + semantic considerations for each modality. + + + ### Per-layer quantization + + This is the general syntax of the `!quant.uniform` type representing + per-layer quantization: + + ``` + `!quant.uniform` `<` + storedType (`<` storageMin `:` storageMax `>`)? `:` + expressedType `,` + scale (`:` zeroPoint)? + `>` + ``` + + The type contains the following parameters: + + - `storedType`: Integer type of the value stored in memory. This type + conveys the bit width and signedness of the quantized stored value. + Signed integer types are represented as `'i' bitWidth` (e.g., `i8`), + while unsigned integer types are represented as `'u' bitWidth` (e.g., + `u8`). + + - `storageMin`, `storageMax`: Optional bounds for the stored value. If + given, they must be within the range of `storedType`. If omitted, the + entire range of `storedType` is allowed (e.g., `-128...127` for `i8` or + `0...255` for `u8`). + + - `expressedType`: Floating-point type of the value expressed by this + quantized type (e.g., `f32`, `f80`, `bf16`, or `tf32`). + + - `scale`: Floating-point value of type `expressedType` used in the + conversion between stored and expressed values. + + - `zeroPoint`: Optional integer value of type `storageType` used in the + conversion between stored and expressed values. If omitted, the default + is 0. + + Type conversions, rounding methods, and clamping actions aside, the + relationship between the expressed and stored values as encoded in a + quantized type is denoted by the following formula: + + $$ + expressedValue = (storedValue ~-~ zeroPoint) ~\times~ scale + $$ + + Operations `quant.qcast` (quantize cast) and `quant.dcast` (dequantize + cast) can be used to quantize a floating-point value and dequantize a + stored value, respectively. See the documentation for these operations for + details on how the quantization and dequantization processes are influenced + by the `!quant.uniform` type parameters. + + Here are some examples of the use of `!quant.uniform` with per-layer + quantization: + + ``` + // An 8-bit signed integer type is used to represent a 32-bit float. No + // clamping information is provided, so the full [-128, 127] range is + // available. The scale is set to 3.0, and the zero point takes its default + // 0 value. + !quant.uniform + + // A 16-bit unsigned integer type is used to represent a 32-bit float. Out + // of the 16 bits, only 10 are used, acoording to the 0..1023 clamping + // range. The type sets the scale to 1.23 and the zero point to 512. + !quant.uniform:f32, 1.23:512> + ``` + + ### Per-channel quantization + + The general syntax of the `!quant.uniform` type representing per-channel + quantization is as follows: + + ``` + `!quant.uniform` `<` + storedType (`<` storageMin `:` storageMax `>`)? `:` + expressedType `:` + channelAxis `,` + `{` + scale0 (`:` zeroPoint0)? `,` + scale1 (`:` zeroPoint1)? ... + '}' + `>` + ``` + + In this data type, there are multiple pairs of `scale` and `zeroPoint` + values. The `channelAxis` field represents the dimension of the containing + tensor acting as the channel. The size of the tensor along this dimension + is expected to match the number of provided `scale`-`zeroPoint` pairs, and + a given pair *i* applies to all elements in the tensor whose index along + dimension `channelAxis` is *i*. A quantized data type using per-channel + quantization is always expected to be contained within a tensor type. + + Here are some examples: + + ``` + // A 2x3x4 tensor contains 8-bit signed integers representing 32-bit + // floats. Dimension 1 of the tensor acts as the channel dimension. Its + // size 3 matches the number of provided scale values. Tensor elemenets at + // positions [*][0][*], [*][1][*], and [*][2][*] use scales 3.0, 4.0, and + // 5.0, respectively. + tensor<2x3x4x!quant.uniform> + + // A 2D dynamically sized tensor contains 16-bit unsigned integers + // representing 32-bit floats. Dimension 0 of the tensor acts as the + // channel dimension. Since 2 scale and zero-point values are provided, the + // size of dimension 0 is expected to be 2 at runtime. Tensor elements + // [0][*] use scale 2.0 and zero point 10, while elements [1][*] use scale + // 3.0 and zero point 20. + tensor> + ``` + + + ## Per-axis quantization integrity + + When type `!quant.uniform` contains per-axis quantization information, the + rules below are enforced. These rules guarantee that the quantization + information encoded in the data type is applicable to the context in which + the quantized type is used. For efficiency, these rules are actively + enforced by the verifiers of `quant` dialect ops, but they must be + respected in any context in which the `!quant.uniform` data type is used, + such as the header of a `func.func` op, or the input of an arithmetic + operation. + + - A quantized type with per-channel quantization information must be the + element type of a tensor container type, and may not occur directly as + the data type of a scalar value. + + ``` + // Incorrect. Type !quant.uniform specifies per-channel quantization for a + // scalar type. + %result = quant.qcast %input : f32 to !quant.uniform + + // Correct. Type `!quant.uniform` with per-channel quantization is wrapped + // in a `tensor` type. + %result = quant.qcast %input : tensor<2xf32> to tensor<2x!quant.uniform> + ``` + + - If the tensor containing the `!quant.uniform` type is ranked, its rank + must be greater than the channel axis specified in the quantized type. + + ``` + // Incorrect. The tensor rank (2) is not greater than the channel axis in + // the quantized type (3). + %result = quant.qcast %input : tensor<1x2xf32> to tensor<1x2x!quant.uniform> + + // Correct. The tensor rank (2) is now greater than the channel axis (1): + %result = quant.qcast %input : tensor<1x2xf32> to tensor<1x2x!quant.uniform> + ``` + + - If the axis dimension in the containing tensor is static, its size must + be equal to the number of scales present in the quantized type. + + ``` + // Incorrect. The channel axis is 1, and the size of dimension 1 in the + // containing tensor is 3. However, there are 4 scale values present in the + // quantized type. + %result = quant.qcast %input : tensor to tensor> + + // Correct. The quantized type now includes 3 scale values, matching the + // size of dimension 1 of the result tensor. + %result = quant.qcast %input : tensor to tensor> + ``` + }]; + let cppNamespace = "::mlir::quant"; + let useDefaultTypePrinterParser = 1; +} + + +//===----------------------------------------------------------------------===// +// Type predicates +//===----------------------------------------------------------------------===// + +class quant_ScalarOrTensorOf : + Type.predicate]>, + "scalar or tensor of " # etype.summary>; + +def quant_QuantizedType : + Type($_self)">, "quantized type">; + +def quant_ScalarType : + Type, + "signless integer, float, or quantized scalar">; + +def quant_IntegerOrQuantizedType : + Type, + "signless integer or quantized type">; + +def quant_FloatScalarOrTensor : + quant_ScalarOrTensorOf; + +def quant_IntegerScalarOrTensor : + quant_ScalarOrTensorOf; + +def quant_QuantizedScalarOrTensor : + quant_ScalarOrTensorOf; + +def quant_IntegerOrQuantizedScalarOrTensor : + quant_ScalarOrTensorOf; + + +//===----------------------------------------------------------------------===// +// Traits +//===----------------------------------------------------------------------===// + +def quant_SameScalarOrTensorShape : + PredOpTrait< + "input and result are both scalars or both tensors with matching shape", + Or<[ + And<[ + TypeIsPred<"input", quant_ScalarType>, + TypeIsPred<"result", quant_ScalarType> + ]>, + And<[ + TypeIsPred<"input", AnyUnrankedTensor>, + TypeIsPred<"result", AnyUnrankedTensor> + ]>, + And<[ + TypeIsPred<"input", AnyRankedTensor>, + TypeIsPred<"result", AnyRankedTensor>, + AllShapesMatch<["input", "result"]>.predicate + ]> + ]> + >; + +def quant_IntegerAndQuantizedCombination : + PredOpTrait< + "input must be integer and result must be quantized, or vice versa", + Or<[ + And<[ + TypeIsPred<"input", quant_QuantizedScalarOrTensor>, + TypeIsPred<"result", quant_IntegerScalarOrTensor> + ]>, + And<[ + TypeIsPred<"input", quant_IntegerScalarOrTensor>, + TypeIsPred<"result", quant_QuantizedScalarOrTensor> + ]> + ]> + >; + +#endif // QUANT_BASE diff --git a/mlir/include/mlir/Dialect/Quant/QuantDialectBytecode.td b/mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td similarity index 100% rename from mlir/include/mlir/Dialect/Quant/QuantDialectBytecode.td rename to mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td new file mode 100644 index 0000000000000..6ef925146dce6 --- /dev/null +++ b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td @@ -0,0 +1,243 @@ +//===- QuantOps.td - Quantization operation definition -----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the operation definition file for Quantization. +// +//===----------------------------------------------------------------------===// + +#ifndef QUANT_OPS +#define QUANT_OPS + +include "mlir/Dialect/Quant/IR/QuantBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +//===----------------------------------------------------------------------===// +// Base classes +//===----------------------------------------------------------------------===// + +class quant_Op traits> : + Op; + +//===----------------------------------------------------------------------===// +// Quantization casts +//===----------------------------------------------------------------------===// + +def quant_DequantizeCastOp : quant_Op<"dcast", [ + Pure, + quant_SameScalarOrTensorShape]> { + let summary = "Dequantize cast operation"; + let description = [{ + Convert an input quantized value into its expressed floating-point value. + The dequantization process consists of the following steps: + + ``` + def dequantize(quantizedValue: quantizedType) -> expressedType: + storedValue = reinterpretCast(quantizedValue, storageType) + storedValueFloat = convertIntToFloat(storedValue, expressedType) + zeroPointFloat = convertIntToFloat(zeroPoint, expressedType) + expressedValue = (storedValueFloat - zeroPointFloat) * scale + return expressedValue + ``` + + Here, `storageType`, `expressedType`, `scale`, and `zeroPoint` are obtained + from the corresponding parameters encoded in `quantizedType`. For + per-channel quantization, the appropriate `scale` and `zeroPoint` values + are used for each tensor element computation according to the channel the + element belongs to. + + The numerical results produced by the algorithm above may vary depending on + the rounding methods used by `convertIntToFloat()`, subtraction (`-`), and + multiplication (`*`). This operation does not define specific rounding + methods; instead, it is the responsibility of a transform pipeline to + determine which rounding method to apply when this operation is broken down + into lower-level dialects. + + The operation must satisfy the following syntactic constraints: + + - Operand `input` must be a scalar or tensor of type `!quant.uniform`. + + - The result type must be a floating-point scalar or tensor. + + - The `expressedType` parameter of the `!quant.uniform` type of the input + must match the floating-point type of the result. + + - The operand and result types must be both scalars or both tensors. If + tensors, they must be both ranked or both unranked. If ranked, both must + have the same shape, including matching static and dynamic dimensions. + + - If the operand uses per-channel quantization, its `!quant.uniform` type + must adhere to the [Per-axis quantization + integrity](#per-axis-quantization-integrity) guidelines. + + Examples: + + ``` + // Dequantize a scalar quantized value + %result = quant.dcast %input : !quant.uniform to f32 + + // Dequantize a dynamically shaped tensor of quantized values + %result = quant.dcast %input : tensor> to tensor + + // Dequantize an unranked tensor using per-axis quantization information + %result = quant.dcast %input : tensor<*x!quant.uniform> to tensor<*xf32> + ``` + }]; + let arguments = (ins quant_QuantizedScalarOrTensor:$input); + let results = (outs quant_FloatScalarOrTensor:$result); + let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)"; + let hasVerifier = 1; + let hasFolder = 1; + let extraClassDeclaration = [{ + /// Return the float type of the scalar or tensor result. + FloatType getFloatType(); + + /// Return the quantized type of the scalar or tensor input. + quant::QuantizedType getQuantizedType(); + }]; +} + +def quant_QuantizeCastOp : quant_Op<"qcast", [ + Pure, + quant_SameScalarOrTensorShape]> { + let summary = "Quantize cast operation"; + let description = [{ + Convert a floating-point value to a quantized type. The quantization + process consists of the following steps: + + ``` + def quantize(expressedValue: expressedType) -> quantizedType: + zeroPointFloat = convertIntToFloat(zeroPoint, expressedType) + scaledValue = expressedValue / scale + storedValueFloat = scaledValue + zeroPointFloat + storedValue = convertFloatToInt(storedValueFloat, storageType) + storedValueClamped = clamp(storedValue, storageMin, storageMax) + quantizedValue = reinterpretCast(storedValueClamped, quantizedType) + return quantizedValue + ``` + + Here, `storageType`, `storageMin`, `storageMax`, `expressedType`, `scale`, + and `zeroPoint` are obtained from the corresponding parameters encoded in + `quantizedType`. For per-channel quantization, the appropriate `scale` and + `zeroPoint` values are used for each tensor element computation according + to the channel the element belongs to. + + The numerical results produced by the algorithm above may vary depending on + the rounding methods used by `convertIntToFloat()`, `convertFloatToInt()`, + `clamp()`, division (`/`), and addition (`+`). This operation does not + define specific rounding methods; instead, it is the responsibility of a + transform pipeline to determine which rounding method to apply when this + operation is broken down into lower-level dialects. + + The operation must satisfy the following syntactic constraints: + + - Operand `input` must be a floating-point scalar or tensor. + + - The result type must be a scalar or tensor of type `!quant.uniform`. + + - The `expressedType` parameter in the `!quant.uniform` type of the result + must match the floating-point type of the input. + + - The operand and result types must be both scalars or both tensors. If + tensors, they must be both ranked or both unranked. If ranked, both must + have the same shape, including matching static and dynamic dimensions. + + - If the result uses per-channel quantization, its `!quant.uniform` type + must adhere to the [Per-axis quantization + integrity](#per-axis-quantization-integrity) guidelines. + + Examples: + + ``` + // Quantize a scalar floating-point value + %result = quant.qcast %input : f32 to !quant.uniform + + // Quantize a dynamically shaped tensor of quantized values + %result = quant.qcast %input : tensor to tensor> + + // Quantize an unranked tensor using per-axis quantization information + %result = quant.qcast %input : tensor<*xf32> to tensor<*x!quant.uniform> + ``` + }]; + let arguments = (ins quant_FloatScalarOrTensor:$input); + let results = (outs quant_QuantizedScalarOrTensor:$result); + let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)"; + let hasVerifier = 1; + let hasFolder = 1; + let extraClassDeclaration = [{ + /// Return the float type of the scalar or tensor input. + FloatType getFloatType(); + + /// Return the quantized type of the scalar or tensor result. + quant::QuantizedType getQuantizedType(); + }]; +} + +def quant_StorageCastOp : quant_Op<"scast", [ + Pure, + quant_SameScalarOrTensorShape, + quant_IntegerAndQuantizedCombination]> { + let summary = "Storage cast operation"; + let description = [{ + Convert a value from a quantized type to the corresponding signless integer + storage type, or vice versa. This conversion simply involves a + reinterpretation of the input bits and does not involve any data + manipulation. + + The following syntactic restrictions must be met: + + - Operand `input` must be a scalar or tensor of a signless integer or + `!quant.uniform` type. + + - The result must be a scalar or tensor of a signless integer or + `!quant.uniform` type. + + - If the operand is a scalar or tensor of type integer, the result must be + a scalar or tensor of type `!quant.uniform`, and vice versa. + + - The operand and result must be both scalars or both tensors. If tensors, + they must be both ranked or both unranked. If ranked, both must have the + same shape, including matching static and dynamic dimensions. + + - The width of the `storageType` parameter of the quantized type of the + operand or result must match the width of the signless integer type of + the operand or result. + + - If the operand or result uses per-channel quantization, its + `!quant.uniform` type must adhere to the [Per-axis quantization + integrity](#per-axis-quantization-integrity) guidelines. + + Examples: + + ``` + // Cast a scalar quantized value into its storage type + %result = quant.scast %input : !quant.uniform to i8 + + // Cast a dynamically shaped tensor of quantized values into their storage type + %result = quant.scast %input : tensor> to tensor + + // Cast an unranked tensor of signless integers into a quantized type using + // per-channel quantization + %result = quant.scast %input : tensor<*xi8> to tensor<*x!quant.uniform> + ``` + }]; + let arguments = (ins quant_IntegerOrQuantizedScalarOrTensor:$input); + let results = (outs quant_IntegerOrQuantizedScalarOrTensor:$result); + let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)"; + let hasVerifier = 1; + let hasFolder = 1; + let extraClassDeclaration = [{ + /// Return the integer type used either in the input or the result. + IntegerType getIntegerType(); + + /// Return the quantized type used either in the input or the result. + quant::QuantizedType getQuantizedType(); + }]; +} + +#endif // QUANT_OPS diff --git a/mlir/include/mlir/Dialect/Quant/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h similarity index 98% rename from mlir/include/mlir/Dialect/Quant/QuantTypes.h rename to mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h index 57a2aa2983365..43440ba623b9c 100644 --- a/mlir/include/mlir/Dialect/Quant/QuantTypes.h +++ b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_QUANT_QUANTTYPES_H -#define MLIR_DIALECT_QUANT_QUANTTYPES_H +#ifndef MLIR_DIALECT_QUANT_IR_QUANTTYPES_H +#define MLIR_DIALECT_QUANT_IR_QUANTTYPES_H #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" @@ -114,6 +114,10 @@ class QuantizedType : public Type { /// The maximum value that storageType can take. int64_t getStorageTypeMax() const; + /// Return whether the storage type has explicit min or max boundaries + /// different from the minimum and maximum representable values. + bool hasStorageTypeBounds() const; + /// Gets the integral bit width that the underlying storage type can exactly /// represent. For integral storage types, this will just be their width. unsigned getStorageTypeIntegralWidth() const; @@ -413,4 +417,4 @@ class CalibratedQuantizedType } // namespace quant } // namespace mlir -#endif // MLIR_DIALECT_QUANT_QUANTTYPES_H +#endif // MLIR_DIALECT_QUANT_IR_QUANTTYPES_H diff --git a/mlir/include/mlir/Dialect/Quant/QuantOps.td b/mlir/include/mlir/Dialect/Quant/QuantOps.td deleted file mode 100644 index 7937265ce2f20..0000000000000 --- a/mlir/include/mlir/Dialect/Quant/QuantOps.td +++ /dev/null @@ -1,103 +0,0 @@ -//===- QuantOps.td - Quantization operation definition -----*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This is the operation definition file for Quantization. -// -//===----------------------------------------------------------------------===// - -#ifndef DIALECT_QUANT_QUANT_OPS_ -#define DIALECT_QUANT_QUANT_OPS_ - -include "mlir/Dialect/Quant/QuantOpsBase.td" -include "mlir/Interfaces/InferTypeOpInterface.td" -include "mlir/Interfaces/SideEffectInterfaces.td" - -//===----------------------------------------------------------------------===// -// Base classes -//===----------------------------------------------------------------------===// - -class quant_Op traits> : - Op; - -//===----------------------------------------------------------------------===// -// Quantization casts -//===----------------------------------------------------------------------===// - -def quant_QuantizeCastOp : quant_Op<"qcast", [Pure]> { - let summary = "convert a quantizable type to a quantized type"; - let description = [{ - A QuantizeCast `qcast` represents a potential type shift from a quantizable - type to a quantized type. - - At runtime, a `qcast` will apply the transformation expressed by its - operand and result type. For flexibility during transformation, it is also - possible to have a `qcast` that performs no transformation (both its - operand and result type are quantizable). - - A `qcast` will typically originate from either: - a) An expressed or implied constraint in the source dialect which signals - that a certain level of quantization is possible or required. - b) An inference made by a quantization algorithm indicating that a - quantized representation may be acceptable. - - Especially early in transformation, it is common to have pairs of - `qcast` and `dcast` at points where a transition to a quantized type is - required. In addition, it is also common to have an identity `qcast` - (where the operand and result type are not quantized) at all points where - it is legal to use a quantized representation (but is not known to be - acceptable). - }]; - let arguments = (ins quant_RealValueType:$arg); - let results = (outs quant_RealValueType:$res); -} - -def quant_DequantizeCastOp : quant_Op<"dcast", [Pure]> { - let summary = "convert back from a quantized to quantizable (expressed) type operation"; - let description = [{ - A DequantizeCast op `dcast` represents the inverse of a `qcast`, - converting back from a quantized to quantizable (expressed) type. - - Like `qcast`s, a `dcast` is allowed to have both its operand and result - as non quantized types. This facilitates transformations and marks edges - where the computation must be carried out in the expressed type. - - Especially early in transformation, it is common to have `dcast`s on - all operands to ops that must operate with the expressed type (typically - math ops prior to lowering to target-specific, quantized kernels). - }]; - let arguments = (ins quant_RealValueType:$arg); - let results = (outs quant_RealValueType:$res); -} - -def quant_StorageCastOp : quant_Op<"scast", [Pure]> { - let summary = "cast from or to a type based on the storage type and the corresponding quantized type"; - let description = [{ - A StorageCast `scast` represents a cast from or to a type based on the - storage type and a type based on a corresponding quantized type. - - This op exists to ensure type coherency for between parts of the computation - which are operating directly on an underlying storage type and those which - operate on quantized values. - - Examples from storage to quantized type: - ``` - i8 -> !quant<"uniform[i8:f32]{1.0}"> - ``` - ``` - tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> - ``` - ``` - vector<4xi8> -> vector<4x!quant<"uniform[i8:f32]{1.0}">> - ``` - }]; - let arguments = (ins quant_RealOrStorageValueType:$arg); - let results = (outs quant_RealOrStorageValueType:$res); - let hasFolder = 1; -} - -#endif // DIALECT_QUANT_QUANT_OPS_ diff --git a/mlir/include/mlir/Dialect/Quant/QuantOpsBase.td b/mlir/include/mlir/Dialect/Quant/QuantOpsBase.td deleted file mode 100644 index da822d0a61deb..0000000000000 --- a/mlir/include/mlir/Dialect/Quant/QuantOpsBase.td +++ /dev/null @@ -1,74 +0,0 @@ -//===- QuantOpsBase.td - Quantization dialect base ---------*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Predicates for types in the Quantization dialect. -// -//===----------------------------------------------------------------------===// - -#ifndef DIALECT_QUANT_QUANT_OPS_BASE_ -#define DIALECT_QUANT_QUANT_OPS_BASE_ - -include "mlir/IR/OpBase.td" - -def Quantization_Dialect : Dialect { - let name = "quant"; - let cppNamespace = "::mlir::quant"; - - let useDefaultTypePrinterParser = 1; -} - -//===----------------------------------------------------------------------===// -// Quantization type definitions -//===----------------------------------------------------------------------===// - -class quant_TypedPrimitiveOrContainer : - Type.predicate, - VectorOf<[etype]>.predicate]>, - "primitive/tensor/vector of " # etype.summary>; - -// An implementation of QuantizedType. -def quant_QuantizedType : - Type($_self)">, "QuantizedType">; - -// A primitive type that can represent a real value. This is either a -// floating point value or a quantized type. -def quant_RealPrimitiveType : - Type, - "real valued primitive (float or quantized type)">; - -// A primitive type that can represent a storage value. This is either an -// integer or quantized type. -def quant_StoragePrimitiveType : - Type, - "quantized storage primitive (integer or quantized type)">; - -// A primitive or container of RealPrimitiveType. -def quant_RealValueType : - quant_TypedPrimitiveOrContainer; - -// A primitive or container of StoragePrimitiveType. -def quant_StorageValueType : - quant_TypedPrimitiveOrContainer; - -// Either a real valued or storage primitive or container type. -def quant_RealOrStorageValueType : - Type, - "real valued or storage primitive or container type">; - -// An implementation of UniformQuantizedType. -def quant_UniformQuantizedType : - DialectType($_self)">, - "UniformQuantizedType">; - -// Predicate for detecting a container or primitive of UniformQuantizedType. -def quant_UniformQuantizedValueType : - quant_TypedPrimitiveOrContainer; - -#endif // DIALECT_QUANT_QUANT_OPS_BASE_ diff --git a/mlir/include/mlir/Dialect/Quant/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Quant/Transforms/CMakeLists.txt new file mode 100644 index 0000000000000..30f7c1696bdb9 --- /dev/null +++ b/mlir/include/mlir/Dialect/Quant/Transforms/CMakeLists.txt @@ -0,0 +1,5 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name Quant) +add_public_tablegen_target(MLIRQuantTransformsIncGen) + +add_mlir_doc(Passes QuantPasses ./ -gen-pass-doc) diff --git a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.h b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.h new file mode 100644 index 0000000000000..84be2a21b34ed --- /dev/null +++ b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.h @@ -0,0 +1,29 @@ +//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_QUANT_TRANSFORMS_PASSES_H_ +#define MLIR_DIALECT_QUANT_TRANSFORMS_PASSES_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace quant { + +#define GEN_PASS_DECL +#include "mlir/Dialect/Quant/Transforms/Passes.h.inc" + +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/Quant/Transforms/Passes.h.inc" + +void populateLowerQuantOpsPatterns(RewritePatternSet &patterns); + +} // namespace quant +} // namespace mlir + +#endif // MLIR_DIALECT_QUANT_TRANSFORMS_PASSES_H_ diff --git a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td new file mode 100644 index 0000000000000..b25296d4db5a9 --- /dev/null +++ b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td @@ -0,0 +1,49 @@ +//===-- Passes.td - Arith pass definition file --------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_QUANT_TRANSFORMS_PASSES +#define MLIR_DIALECT_QUANT_TRANSFORMS_PASSES + +include "mlir/Pass/PassBase.td" + +def LowerQuantOps : Pass<"lower-quant-ops", "func::FuncOp"> { + let summary = "Lower quant.dcast and quant.qcast ops"; + let description = [{ + Lower quantization (`quant.qcast`) and dequantization (`quant.dcast`) ops + into other core dialects. + + The lowering process generates storage type casts in the form of + `quant.scast` ops to act as an interface between the original quantized + types of operands and results and their corresponding storage types used in + the generated arithmetic computations. + }]; + let dependentDialects = [ + "arith::ArithDialect", + "linalg::LinalgDialect", + "quant::QuantDialect", + "shape::ShapeDialect", + "tensor::TensorDialect" + ]; +} + +def StripFuncQuantTypes : Pass<"strip-func-quant-types"> { + let summary = "Strip quantized types from function headers"; + let description = [{ + Identify occurrences of function arguments using a quantized type and + replace them with a new value of the corresponding storage (signless + integer) type. For each converted argument, a `quant.scast` op is introduced + at the head of the function's entry block converting the new integer + argument into the original quantized value. + }]; + let dependentDialects = [ + "func::FuncDialect", + "quant::QuantDialect" + ]; +} + +#endif // MLIR_DIALECT_QUANT_TRANSFORMS_PASSES diff --git a/mlir/include/mlir/Dialect/Quant/FakeQuantSupport.h b/mlir/include/mlir/Dialect/Quant/Utils/FakeQuantSupport.h similarity index 93% rename from mlir/include/mlir/Dialect/Quant/FakeQuantSupport.h rename to mlir/include/mlir/Dialect/Quant/Utils/FakeQuantSupport.h index 367d468b2acf1..6551efc6242a6 100644 --- a/mlir/include/mlir/Dialect/Quant/FakeQuantSupport.h +++ b/mlir/include/mlir/Dialect/Quant/Utils/FakeQuantSupport.h @@ -34,10 +34,10 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_QUANT_FAKEQUANTSUPPORT_H_ -#define MLIR_DIALECT_QUANT_FAKEQUANTSUPPORT_H_ +#ifndef MLIR_DIALECT_QUANT_UTILS_FAKEQUANTSUPPORT_H_ +#define MLIR_DIALECT_QUANT_UTILS_FAKEQUANTSUPPORT_H_ -#include "mlir/Dialect/Quant/QuantTypes.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" namespace mlir { namespace quant { @@ -64,4 +64,4 @@ fakeQuantAttrsToType(Location loc, unsigned numBits, int32_t quantizedDimension, } // namespace quant } // namespace mlir -#endif // MLIR_DIALECT_QUANT_FAKEQUANTSUPPORT_H_ +#endif // MLIR_DIALECT_QUANT_UTILS_FAKEQUANTSUPPORT_H_ diff --git a/mlir/include/mlir/Dialect/Quant/UniformSupport.h b/mlir/include/mlir/Dialect/Quant/Utils/UniformSupport.h similarity index 97% rename from mlir/include/mlir/Dialect/Quant/UniformSupport.h rename to mlir/include/mlir/Dialect/Quant/Utils/UniformSupport.h index 4119aced4c075..6773f45069c87 100644 --- a/mlir/include/mlir/Dialect/Quant/UniformSupport.h +++ b/mlir/include/mlir/Dialect/Quant/Utils/UniformSupport.h @@ -6,12 +6,12 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_QUANT_UNIFORMSUPPORT_H_ -#define MLIR_DIALECT_QUANT_UNIFORMSUPPORT_H_ +#ifndef MLIR_DIALECT_QUANT_UTILS_UNIFORMSUPPORT_H_ +#define MLIR_DIALECT_QUANT_UTILS_UNIFORMSUPPORT_H_ #include -#include "mlir/Dialect/Quant/QuantTypes.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Types.h" #include "llvm/ADT/APFloat.h" @@ -218,4 +218,4 @@ class UniformQuantizedPerAxisValueConverter { } // namespace quant } // namespace mlir -#endif // MLIR_DIALECT_QUANT_UNIFORMSUPPORT_H_ +#endif // MLIR_DIALECT_QUANT_UTILS_UNIFORMSUPPORT_H_ diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td index 64bacd0e432fe..67b41187e5bfb 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td @@ -40,7 +40,7 @@ def Tosa_Dialect : Dialect { there will be tools to lower from the ML frameworks into TOSA. }]; - let dependentDialects = ["tensor::TensorDialect", "quant::QuantizationDialect"]; + let dependentDialects = ["tensor::TensorDialect", "quant::QuantDialect"]; let cppNamespace = "mlir::tosa"; let hasConstantMaterializer = 1; diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h index 298c97015fe2e..5e80745777b3b 100644 --- a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h +++ b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h @@ -16,8 +16,8 @@ #include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/Dialect/Quant/FakeQuantSupport.h" -#include "mlir/Dialect/Quant/UniformSupport.h" +#include "mlir/Dialect/Quant/Utils/FakeQuantSupport.h" +#include "mlir/Dialect/Quant/Utils/UniformSupport.h" namespace mlir { namespace tosa { diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h index 73dccdb017ee1..7fd0432ddce1b 100644 --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -65,7 +65,7 @@ #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" #include "mlir/Dialect/Polynomial/IR/PolynomialDialect.h" #include "mlir/Dialect/Ptr/IR/PtrDialect.h" -#include "mlir/Dialect/Quant/QuantOps.h" +#include "mlir/Dialect/Quant/IR/Quant.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h" #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" @@ -137,7 +137,7 @@ inline void registerAllDialects(DialectRegistry ®istry) { pdl_interp::PDLInterpDialect, polynomial::PolynomialDialect, ptr::PtrDialect, - quant::QuantizationDialect, + quant::QuantDialect, ROCDL::ROCDLDialect, scf::SCFDialect, shape::ShapeDialect, diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h index 1b9c1b193ace6..dd8b292a87344 100644 --- a/mlir/include/mlir/InitAllPasses.h +++ b/mlir/include/mlir/InitAllPasses.h @@ -35,6 +35,7 @@ #include "mlir/Dialect/Mesh/Transforms/Passes.h" #include "mlir/Dialect/NVGPU/Transforms/Passes.h" #include "mlir/Dialect/OpenACC/Transforms/Passes.h" +#include "mlir/Dialect/Quant/Transforms/Passes.h" #include "mlir/Dialect/SCF/Transforms/Passes.h" #include "mlir/Dialect/SPIRV/Transforms/Passes.h" #include "mlir/Dialect/Shape/Transforms/Passes.h" @@ -82,6 +83,7 @@ inline void registerAllPasses() { memref::registerMemRefPasses(); mesh::registerMeshPasses(); ml_program::registerMLProgramPasses(); + quant::registerQuantPasses(); registerSCFPasses(); registerShapePasses(); spirv::registerSPIRVPasses(); diff --git a/mlir/lib/CAPI/Dialect/Quant.cpp b/mlir/lib/CAPI/Dialect/Quant.cpp index 0a7181d8bc17c..c94dbb5692fdb 100644 --- a/mlir/lib/CAPI/Dialect/Quant.cpp +++ b/mlir/lib/CAPI/Dialect/Quant.cpp @@ -8,12 +8,12 @@ #include "mlir-c/Dialect/Quant.h" #include "mlir/CAPI/Registration.h" -#include "mlir/Dialect/Quant/QuantOps.h" -#include "mlir/Dialect/Quant/QuantTypes.h" +#include "mlir/Dialect/Quant/IR/Quant.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" using namespace mlir; -MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(quant, quant, quant::QuantizationDialect) +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(quant, quant, quant::QuantDialect) //===---------------------------------------------------------------------===// // QuantizedType diff --git a/mlir/lib/Dialect/Quant/CMakeLists.txt b/mlir/lib/Dialect/Quant/CMakeLists.txt index 037bba8dcb5c9..31167e6af908b 100644 --- a/mlir/lib/Dialect/Quant/CMakeLists.txt +++ b/mlir/lib/Dialect/Quant/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(IR) +add_subdirectory(Transforms) add_subdirectory(Utils) diff --git a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp index c0c00fb4893cb..6a4ac310eb052 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp @@ -9,8 +9,8 @@ #include "QuantDialectBytecode.h" #include "mlir/Bytecode/BytecodeImplementation.h" -#include "mlir/Dialect/Quant/QuantOps.h" -#include "mlir/Dialect/Quant/QuantTypes.h" +#include "mlir/Dialect/Quant/IR/Quant.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" #include "mlir/IR/Diagnostics.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/SmallVector.h" @@ -31,7 +31,7 @@ static LogicalResult readDoubleAPFloat(DialectBytecodeReader &reader, return success(); } -#include "mlir/Dialect/Quant/QuantDialectBytecode.cpp.inc" +#include "mlir/Dialect/Quant/IR/QuantDialectBytecode.cpp.inc" /// This class implements the bytecode interface for the Quant dialect. struct QuantDialectBytecodeInterface : public BytecodeDialectInterface { @@ -64,6 +64,6 @@ struct QuantDialectBytecodeInterface : public BytecodeDialectInterface { }; } // namespace -void quant::detail::addBytecodeInterface(QuantizationDialect *dialect) { +void quant::detail::addBytecodeInterface(QuantDialect *dialect) { dialect->addInterfaces(); } diff --git a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.h b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.h index 9e9cbf66d84d9..eef2b5bbefecc 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.h +++ b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.h @@ -15,12 +15,12 @@ #define LIB_MLIR_DIALECT_QUANT_IR_QUANTDIALECTBYTECODE_H namespace mlir::quant { -class QuantizationDialect; +class QuantDialect; namespace detail { /// Add the interfaces necessary for encoding the quantization dialect /// components in bytecode. -void addBytecodeInterface(QuantizationDialect *dialect); +void addBytecodeInterface(QuantDialect *dialect); } // namespace detail } // namespace mlir::quant diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp index c9a6bbc9ceeea..c584903f3a15d 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp @@ -6,44 +6,209 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Quant/QuantOps.h" #include "QuantDialectBytecode.h" #include "TypeDetail.h" -#include "mlir/Dialect/Quant/QuantTypes.h" +#include "mlir/Dialect/Quant/IR/Quant.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/ADT/Twine.h" -#include "llvm/Support/MathExtras.h" -#include +#include "mlir/IR/TypeUtilities.h" -using namespace mlir; -using namespace mlir::quant; -using namespace mlir::quant::detail; +#include "mlir/Dialect/Quant/IR/QuantOpsDialect.cpp.inc" -#include "mlir/Dialect/Quant/QuantOpsDialect.cpp.inc" -void QuantizationDialect::initialize() { +namespace mlir { +namespace quant { + +namespace { + +// Verify the integrity of per-axis quantization information, if present. +// +// - quantizedType +// Any quantized type. Any quantized type with no per-axis quantization is +// ignored. +// +// - containerType +// Original input or result type of the operation using the provided quantized +// type. Used to ensure that the quantized type appears within a tensor and +// that the tensor is compatible with per-axis quantization information. +// +LogicalResult verifyPerAxisQuantization(Operation *op, + QuantizedType quantizedType, + Type containerType) { + auto quantizedPerAxisType = dyn_cast(quantizedType); + if (!quantizedPerAxisType) + return success(); + + auto tensorType = dyn_cast(containerType); + if (!tensorType) + return op->emitError("scalar types may not use per-axis quantization"); + + if (!tensorType.hasRank()) + return success(); + + int64_t quantizedDimension = quantizedPerAxisType.getQuantizedDimension(); + if (quantizedDimension >= tensorType.getRank()) + return op->emitError("quantized dimension must be less than tensor rank"); + + int64_t quantizedDimensionSize = tensorType.getDimSize(quantizedDimension); + if (quantizedDimensionSize != ShapedType::kDynamic && + quantizedDimensionSize != (int64_t)quantizedPerAxisType.getScales().size()) + return op->emitError( + "quantized dimension size does not match number of scales"); + + return success(); +} + +// Common verification logic for 'quant.dcast' and 'quant.qcast' ops. +// +// - quantizedType +// Quantized type used in the input ('quant.dcast') or result ('quant.qcast'), +// whether as a primitive type or in a tensor. +// +// - floatType +// Float type used in the input ('quant.qcast') or result ('quant.dcast'), +// whether as a primitive type or in a tensor. +// +// - containerType +// Type of original input or result. +// +LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType, + FloatType floatType, Type containerType) { + if (quantizedType.getExpressedType() != floatType) + return op->emitError( + "expressed type in quantized type expected to match float type"); + + // Veriy integrity of per-axis quantization information, if present. + return verifyPerAxisQuantization(op, quantizedType, containerType); +} + +} // namespace + + +//===----------------------------------------------------------------------===// +// Dialect +//===----------------------------------------------------------------------===// + +void QuantDialect::initialize() { addTypes(); addOperations< #define GET_OP_LIST -#include "mlir/Dialect/Quant/QuantOps.cpp.inc" +#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc" >(); - addBytecodeInterface(this); + detail::addBytecodeInterface(this); +} + + +//===----------------------------------------------------------------------===// +// DequantizeCastOp +//===----------------------------------------------------------------------===// + +LogicalResult DequantizeCastOp::verify() { + return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(), + getInput().getType()); +} + +OpFoldResult DequantizeCastOp::fold(FoldAdaptor adaptor) { + // Matches x -> quant.qcast -> quant.dcast -> y, replacing the quant.dcast op + // with the value of x. Values x and y are guaranteed to be of the same type + // in this pattern. + auto srcQcastOp = getInput().getDefiningOp(); + if (!srcQcastOp) + return {}; + assert(srcQcastOp.getInput().getType() == getType()); + return srcQcastOp.getInput(); +} + +FloatType DequantizeCastOp::getFloatType() { + return cast(getElementTypeOrSelf(getResult().getType())); +} + +QuantizedType DequantizeCastOp::getQuantizedType() { + return cast(getElementTypeOrSelf(getInput().getType())); +} + + +//===----------------------------------------------------------------------===// +// QuantizeCastOp +//===----------------------------------------------------------------------===// + +LogicalResult QuantizeCastOp::verify() { + return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(), + getInput().getType()); +} + +OpFoldResult QuantizeCastOp::fold(FoldAdaptor adaptor) { + // Matches x -> quant.dcast -> quant.qcast -> y, replacing the quant.qcast op + // with the value of x if the casts invert each other. Contrary to the folding + // pattern in quant.dcast (i.e., x -> quant.qcast -> quant.dcast -> y), values + // x and y are not guaranteed to be of the same type here, as they may use + // different quantization parameters. + auto srcDcastOp = getInput().getDefiningOp(); + if (!srcDcastOp || srcDcastOp.getInput().getType() != getType()) + return {}; + return srcDcastOp.getInput(); +} + +FloatType QuantizeCastOp::getFloatType() { + return cast(getElementTypeOrSelf(getInput().getType())); +} + +QuantizedType QuantizeCastOp::getQuantizedType() { + return cast(getElementTypeOrSelf(getResult().getType())); +} + + +//===----------------------------------------------------------------------===// +// StorageCastOp +//===----------------------------------------------------------------------===// + +LogicalResult StorageCastOp::verify() { + auto quantizedType = getQuantizedType(); + auto integerType = getIntegerType(); + if (quantizedType.getStorageType() != integerType) + return emitError( + "storage type in quantized type expected to match integer type"); + + // Verify integrity of per-axis quantization information, if available. While + // the quantization type may appear in the input or the result, their tensor + // shapes are guaranteed to be identical at this point. + return verifyPerAxisQuantization(*this, quantizedType, getInput().getType()); } OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) { - // Matches x -> [scast -> scast] -> y, replacing the second scast with the - // value of x if the casts invert each other. - auto srcScastOp = getArg().getDefiningOp(); - if (!srcScastOp || srcScastOp.getArg().getType() != getType()) - return OpFoldResult(); - return srcScastOp.getArg(); + // Matches x -> quant.scast -> quant.scast -> y, replacing the second + // quant.scast with the value of x if the casts invert each other. + auto srcScastOp = getInput().getDefiningOp(); + if (!srcScastOp || srcScastOp.getInput().getType() != getType()) + return {}; + return srcScastOp.getInput(); +} + +IntegerType StorageCastOp::getIntegerType() { + auto inputScalarType = getElementTypeOrSelf(getInput().getType()); + if (auto integerType = dyn_cast(inputScalarType)) + return integerType; + + auto resultScalarType = getElementTypeOrSelf(getResult().getType()); + return cast(resultScalarType); +} + +QuantizedType StorageCastOp::getQuantizedType() { + auto inputScalarType = getElementTypeOrSelf(getInput().getType()); + if (auto quantizedType = dyn_cast(inputScalarType)) + return quantizedType; + + auto resultScalarType = getElementTypeOrSelf(getResult().getType()); + return cast(resultScalarType); } + +} // namespace quant +} // namespace mlir + #define GET_OP_CLASSES -#include "mlir/Dialect/Quant/QuantOps.cpp.inc" +#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc" + diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp index c2ba9c04e8771..ac01b37a55307 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp @@ -6,9 +6,9 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Quant/QuantTypes.h" #include "TypeDetail.h" -#include "mlir/Dialect/Quant/QuantOps.h" +#include "mlir/Dialect/Quant/IR/Quant.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" @@ -20,12 +20,28 @@ using namespace mlir; using namespace mlir::quant; using namespace mlir::quant::detail; +namespace { + +// Return the minimum scale representable in a given float type +double getMinScale(Type expressedType) { + auto floatType = cast(expressedType); + return APFloat::getSmallest(floatType.getFloatSemantics()).convertToDouble(); +} + +// Return the maximum scale representable in a given float type +double getMaxScale(Type expressedType) { + auto floatType = cast(expressedType); + return APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble(); +} + +} // namespace + unsigned QuantizedType::getFlags() const { return static_cast(impl)->flags; } bool QuantizedType::classof(Type type) { - return llvm::isa(type.getDialect()); + return llvm::isa(type.getDialect()); } LogicalResult @@ -73,6 +89,17 @@ int64_t QuantizedType::getStorageTypeMax() const { return static_cast(impl)->storageTypeMax; } +bool QuantizedType::hasStorageTypeBounds() const { + unsigned int integralWidth = getStorageTypeIntegralWidth(); + bool isSignedInteger = isSigned(); + int64_t defaultIntegerMin = + getDefaultMinimumForInteger(isSignedInteger, integralWidth); + int64_t defaultIntegerMax = + getDefaultMaximumForInteger(isSignedInteger, integralWidth); + return defaultIntegerMin != getStorageTypeMin() || + defaultIntegerMax != getStorageTypeMax(); +} + unsigned QuantizedType::getStorageTypeIntegralWidth() const { // NOTE: If ever supporting non-integral storage types, some other scheme // for determining the width will be needed. @@ -293,8 +320,13 @@ LogicalResult UniformQuantizedType::verifyInvariants( return emitError() << "expressed type must be floating point"; // Verify scale. + double minScale = getMinScale(expressedType); + double maxScale = getMaxScale(expressedType); if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale)) return emitError() << "illegal scale: " << scale; + if (scale < minScale || scale > maxScale) + return emitError() << "scale out of expressed type range [" << minScale + << ", " << maxScale << "]"; return success(); } @@ -353,11 +385,20 @@ LogicalResult UniformQuantizedPerAxisType::verifyInvariants( << scales.size() << ", " << zeroPoints.size(); // Verify scale. + double minScale = getMinScale(expressedType); + double maxScale = getMaxScale(expressedType); for (double scale : scales) { if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale)) return emitError() << "illegal scale: " << scale; + if (scale < minScale || scale > maxScale) + return emitError() << "scale out of expressed type range [" << minScale + << ", " << maxScale << "]"; } + // Verify quantized dimension. + if (quantizedDimension < 0) + return emitError() << "illegal quantized dimension: " << quantizedDimension; + return success(); } diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp index 926a8a0aa13d5..851763d8942e8 100644 --- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp +++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Quant/QuantOps.h" -#include "mlir/Dialect/Quant/QuantTypes.h" +#include "mlir/Dialect/Quant/IR/Quant.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Location.h" @@ -317,7 +317,7 @@ static Type parseCalibratedType(DialectAsmParser &parser) { } /// Parse a type registered to this dialect. -Type QuantizationDialect::parseType(DialectAsmParser &parser) const { +Type QuantDialect::parseType(DialectAsmParser &parser) const { // All types start with an identifier that we switch on. StringRef typeNameSpelling; if (failed(parser.parseKeyword(&typeNameSpelling))) @@ -346,12 +346,7 @@ static void printStorageType(QuantizedType type, DialectAsmPrinter &out) { } // storageTypeMin and storageTypeMax if not default. - int64_t defaultIntegerMin = - QuantizedType::getDefaultMinimumForInteger(isSigned, storageWidth); - int64_t defaultIntegerMax = - QuantizedType::getDefaultMaximumForInteger(isSigned, storageWidth); - if (defaultIntegerMin != type.getStorageTypeMin() || - defaultIntegerMax != type.getStorageTypeMax()) { + if (type.hasStorageTypeBounds()) { out << "<" << type.getStorageTypeMin() << ":" << type.getStorageTypeMax() << ">"; } @@ -419,7 +414,7 @@ static void printCalibratedQuantizedType(CalibratedQuantizedType type, } /// Print a type registered to this dialect. -void QuantizationDialect::printType(Type type, DialectAsmPrinter &os) const { +void QuantDialect::printType(Type type, DialectAsmPrinter &os) const { if (auto anyType = llvm::dyn_cast(type)) printAnyQuantizedType(anyType, os); else if (auto uniformType = llvm::dyn_cast(type)) diff --git a/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt new file mode 100644 index 0000000000000..2fd4a41999d45 --- /dev/null +++ b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt @@ -0,0 +1,26 @@ +add_mlir_dialect_library(MLIRQuantTransforms + LowerQuantOps.cpp + StripFuncQuantTypes.cpp + + ADDITIONAL_HEADER_DIRS + {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Quant/Transforms + + DEPENDS + MLIRQuantTransformsIncGen + + LINK_LIBS PUBLIC + MLIRArithDialect + MLIRFuncDialect + MLIRFuncTransforms + MLIRIndexDialect + MLIRIR + MLIRLinalgDialect + MLIRLinalgUtils + MLIRPass + MLIRQuantDialect + MLIRShapeDialect + MLIRTensorDialect + MLIRTransforms + MLIRTransformUtils + + ) diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp new file mode 100644 index 0000000000000..4adeb9218ff8e --- /dev/null +++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp @@ -0,0 +1,676 @@ +//===- LowerQuantOps.cpp - Lower 'quant' dialect ops ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Transforms `quant.dcast` and `quant.qcast` into lower-level ops. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Quant/IR/Quant.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" +#include "mlir/Dialect/Quant/Transforms/Passes.h" +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace quant { + +#define GEN_PASS_DEF_LOWERQUANTOPS +#include "mlir/Dialect/Quant/Transforms/Passes.h.inc" + +namespace { + +// If 'inputType' is a tensor, return its element type. If it is a scalar, +// return it as is. +Type getScalarType(Type inputType) { + if (auto tensorType = dyn_cast(inputType)) + return tensorType.getElementType(); + return inputType; +} + +// Return the shape of an input value as a list of attributes (static dimensions) +// and values (dynamic dimensions). If 'input' is a scalar, an empty list is +// returned. If 'input' is a tensor, its shape is returned. +SmallVector +getScalarOrTensorShape(OpBuilder &builder, Location loc, Value input) { + if (isa(input.getType())) + return tensor::getMixedSizes(builder, loc, input); + return {}; +} + +// If 'referenceType' is a scalar, return 'elementType' as is. If +// 'referenceType' is a tensor, return another tensor with the same shape and +// elements of type 'elementType'. +Type getScalarOrTensorType(Type elementType, Type referenceType) { + if (auto tensorType = dyn_cast(referenceType)) + return tensorType.clone(elementType); + return elementType; +} + +// Return a constant with the given value. If 'referenceType' is a tensor, a +// tensor splat of shape 'referenceShape' is returned. If 'referenceType' is a +// scalar, 'referenceShape' is ignored and a scalar constant is returned. +Value getScalarOrTensorConstant(OpBuilder &builder, Location loc, Value scalar, + Type referenceType, + ArrayRef referenceShape) { + // If the result type is a scalar, return the unmodified scalar constant. + auto tensorType = dyn_cast(referenceType); + if (!tensorType) { + assert(referenceShape.empty()); + return scalar; + } + + // Create tensor splat + auto tensorConstant = + builder.create(loc, scalar, referenceShape); + return tensorConstant; +} + +// Reshape an unranked tensor into a 1D ranked tensor. +// +// - input +// Unranked tensor. +// +// Return values: +// +// - flatInput +// 1D ranked, dynamically shaped tensor. +// +// - inputShape +// 1D extent tensor containing the shape of the original unranked input. +// +std::pair flattenUnrankedTensor(OpBuilder &builder, Location loc, + Value input) { + // Get unranked input shape and total size + auto *context = builder.getContext(); + auto shapeType = shape::getExtentTensorType(context); + auto inputShape = builder.create(loc, shapeType, input); + Value inputSize = builder.create( + loc, builder.getIndexType(), inputShape); + + // Turn input size into 1D tensor + auto flatShapeType = shape::getExtentTensorType(context, 1); + auto flatInputShape = builder.create( + loc, flatShapeType, inputSize); + + // Reshape input tensor into 1D + auto inputType = cast(input.getType()); + auto elementType = inputType.getElementType(); + auto flatInputType = + RankedTensorType::get({ShapedType::kDynamic}, elementType); + auto flatInput = builder.create( + loc, flatInputType, input, flatInputShape); + return std::make_pair(flatInput, inputShape); +} + +// Reshape an unranked tensor into a 3D ranked tensor where the central +// dimension of the result tensor corresponds to dimension 'axis' of the input +// tensor. +// +// - input +// Unranked tensor. +// +// - axis +// Index of the input dimension around which other input dimiensions will be +// collapsed. +// +// - axisSize +// Size of input dimension 'axis'. +// +// Return values: +// +// - flatInput +// 3D ranked tensor of shape [?, axisSize, ?]. +// +// - inputShape +// 1D extent tensor containing the shape of the original unranked input. +// +std::pair flattenUnrankedTensorAroundAxis(OpBuilder &builder, + Location loc, + Value input, + int64_t axis, + int64_t axisSize) { + // Get full tensor shape + auto *context = builder.getContext(); + auto indexType = builder.getIndexType(); + auto shapeType = shape::getExtentTensorType(context); + auto inputShape = builder.create(loc, shapeType, input); + + // Get shape and sizes on left and right of axis + auto axisValue = builder.create(loc, axis); + auto axisNextValue = builder.create(loc, axis + 1); + auto shapeLeft = builder.create( + loc, TypeRange{shapeType, shapeType}, inputShape, axisValue) + .getResult(0); + auto sizeLeft = builder.create( + loc, indexType, shapeLeft); + auto shapeRight = builder.create( + loc, TypeRange{shapeType, shapeType}, inputShape, axisNextValue) + .getResult(1); + auto sizeRight = builder.create( + loc, indexType, shapeRight); + + // Compute flat input shape as a 3-element 1D tensor + auto axisSizeValue = builder.create(loc, axisSize); + auto flatShapeType = shape::getExtentTensorType(context, 3); + auto flatInputShape = builder.create( + loc, flatShapeType, ValueRange{sizeLeft, axisSizeValue, sizeRight}); + + // Reshape input to 3D tensor + auto inputType = cast(input.getType()); + auto elementType = inputType.getElementType(); + auto flatInputType = RankedTensorType::get( + {ShapedType::kDynamic, axisSize, ShapedType::kDynamic}, elementType); + auto flatInput = builder.create( + loc, flatInputType, input, flatInputShape); + + return std::make_pair(flatInput, inputShape); +} + +// Reshape an input tensor into its original unranked shape. +// +// - input +// Ranked tensor. +// +// - inputShape +// 1D extent tensor. +// +Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input, + Value inputShape) { + auto inputType = cast(input.getType()); + auto elementType = inputType.getElementType(); + auto unrankedType = UnrankedTensorType::get(elementType); + return builder.create(loc, unrankedType, input, inputShape); +} + +// Create a tensor constant containing all scales in a per-channel quantized +// type. Example: +// +// !quant.uniform +// +// produces +// +// %cst = arith.constant dense<[2.0, 3.0]> : tensor<2xf32> +// +Value materializePerChannelScales(OpBuilder &builder, Location loc, + UniformQuantizedPerAxisType quantizedType) { + auto scales = quantizedType.getScales(); + auto expressedType = quantizedType.getExpressedType(); + auto scaleAttrs = llvm::map_to_vector(scales, [&](double scale) -> Attribute { + return builder.getFloatAttr(expressedType, scale); + }); + auto tensorType = RankedTensorType::get({(int64_t) scales.size()}, expressedType); + auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs); + return builder.create(loc, tensorType, scalesAttr); +} + +// Create a tensor constant containing all zero points in a per-channel +// quantized type. Example: +// +// !quant.uniform +// +// produces +// +// %cst = arith.constant dense<[10, 20]> : tensor<2xi8> +// +Value materializePerChannelZeroPoints( + OpBuilder &builder, Location loc, + UniformQuantizedPerAxisType quantizedType) { + auto zeroPoints = quantizedType.getZeroPoints(); + auto storageType = quantizedType.getStorageType(); + auto zeroPointAttrs = llvm::map_to_vector( + zeroPoints, + [&](int64_t zeroPoint) -> Attribute { + return builder.getIntegerAttr(storageType, zeroPoint); + }); + auto tensorType = + RankedTensorType::get({(int64_t)zeroPoints.size()}, storageType); + auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs); + return builder.create(loc, tensorType, zeroPointsAttr); +} + +// Clamp the given scalar or tensor input using the storage bounds encoded in +// the given quantized type, if present. +// +// - input +// Scalar or ranked tensor input. The element type must match the storage type +// of 'quantizedType'. +// +// - inputShape +// If 'input' is a tensor, combination of attributes/values representing its +// static/dynamic dimensions. If 'input' is a scalar, empty list. +// +// - quantizedType +// Per-axis or per-channel quantized type. +Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input, + ArrayRef inputShape, + QuantizedType quantizedType) { + // If quantized type does not narrow down the storage type range, there is + // nothing to do. + if (!quantizedType.hasStorageTypeBounds()) + return input; + + // Materialize bounds + auto inputType = input.getType(); + auto storageType = quantizedType.getStorageType(); + auto storageMinScalar = builder.create( + loc, quantizedType.getStorageTypeMin(), storageType); + auto storageMaxScalar = builder.create( + loc, quantizedType.getStorageTypeMax(), storageType); + auto storageMin = getScalarOrTensorConstant(builder, loc, storageMinScalar, + inputType, inputShape); + auto storageMax = getScalarOrTensorConstant(builder, loc, storageMaxScalar, + inputType, inputShape); + + // Clamp + if (quantizedType.isSigned()) { + input = builder.create(loc, input, storageMin); + input = builder.create(loc, input, storageMax); + } else { + input = builder.create(loc, input, storageMin); + input = builder.create(loc, input, storageMax); + } + return input; +} + +// Emit op 'arith.fptosi' or 'arith.fptoui'. +Value convertFloatToInteger(OpBuilder &builder, Location loc, Value input, + Type resultType, bool isSigned) { + if (isSigned) + return builder.create(loc, resultType, input); + return builder.create(loc, resultType, input); +} + +// Emit op 'arith.sitofp' or 'arith.uitofp'. +Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input, + Type resultType, bool isSigned) { + if (isSigned) + return builder.create(loc, resultType, input); + return builder.create(loc, resultType, input); +} + +// Quantize a scalar or ranked tensor value. The stored value is clamped using +// the storage bounds encoded in the given quantized type. +// +// See function 'convertRanked()' below for a description of the arguments. +Value quantizeValue(OpBuilder &builder, Location loc, Value input, + ArrayRef inputShape, Value scale, + Value zeroPoint, QuantizedType quantizedType) { + // Convert scale to tensor if necessary + auto inputType = input.getType(); + scale = getScalarOrTensorConstant( + builder, loc, scale, inputType, inputShape); + + // Scale input + auto scaledValue = builder.create(loc, input, scale); + + // Skip unnecessary computations if no zero point is given + Value storedValueFloat = scaledValue; + if (!matchPattern(zeroPoint, m_Zero())) { + // Convert zero point to tensor if necessary + zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType, + inputShape); + + // Convert zero point from storage to expressed type + zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, + scale.getType(), + quantizedType.isSigned()); + + // Add zero point to stored value + storedValueFloat = + builder.create(loc, scaledValue, zeroPoint); + } + + // Convert stored value to storage type + auto storageScalarOrTensorType = + getScalarOrTensorType(quantizedType.getStorageType(), inputType); + auto storedValueInt = convertFloatToInteger( + builder, loc, storedValueFloat, storageScalarOrTensorType, + quantizedType.isSigned()); + + // Clamp stored value it if the storage type is bound + auto storedValueClamped = clampScalarOrTensor(builder, loc, storedValueInt, + inputShape, quantizedType); + return storedValueClamped; +} + +// Dequantize a scalar or ranked tensor input. +// +// See function 'convertRanked()' below for a description of the arguments. +Value dequantizeValue(OpBuilder &builder, Location loc, Value input, + ArrayRef inputShape, Value scale, + Value zeroPoint, QuantizedType quantizedType) { + // Convert scale to tensor if necessary + auto inputType = input.getType(); + scale = getScalarOrTensorConstant( + builder, loc, scale, inputType, inputShape); + + // Convert stored value to float + auto result = convertIntegerToFloat( + builder, loc, input, scale.getType(), quantizedType.isSigned()); + + // Skip unnecessary computations if no zero point is given + if (!matchPattern(zeroPoint, m_Zero())) { + // Convert zero point to tensor if necessary + zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType, + inputShape); + + // Convert zero point from storage to expressed type + zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, + scale.getType(), + quantizedType.isSigned()); + + // Subtract zero point to stored value + result = builder.create(loc, result, zeroPoint); + } + + // Multiply by scale + result = builder.create(loc, result, scale); + return result; +} + +// Convert a scalar or ranked tensor input with the given scale and zero point +// values. +// +// - input +// Scalar or ranked tensor value. +// +// - inputShape +// If 'input' is a tensor, combination or attributes/values representing its +// static/dynamic dimensions. If 'input' is a scalar, empty list. +// +// - scale +// Scale as a floating-point scalar value. +// +// - zeroPoint +// Zero point as an integer scalar value. +// +// - quantizedType +// Scalar quantized type of the result ('quant.qcast') or of the input +// ('quant.dcast'). +// +Value convertRanked(OpBuilder &builder, Location loc, Operation *op, + Value input, ArrayRef inputShape, Value scale, + Value zeroPoint, QuantizedType quantizedType) { + if (isa(op)) + return quantizeValue(builder, loc, input, inputShape, scale, zeroPoint, + quantizedType); + if (isa(op)) + return dequantizeValue(builder, loc, input, inputShape, scale, zeroPoint, + quantizedType); + llvm_unreachable("unexpected quant op"); +} + +// Convert an operation using per-layer quantization with a scalar or ranked +// tensor input. +// +// - op +// 'quant.dcast' or 'quant.qcast' op. +// +// - input +// Scalar or ranked tensor. +// +// - quantizedType +// Per-layer quantized type. +// +Value convertPerLayerRanked(OpBuilder &builder, Location loc, Operation *op, + Value input, UniformQuantizedType quantizedType) { + // Create scale and zero point constants + auto expressedType = quantizedType.getExpressedType(); + auto storageType = quantizedType.getStorageType(); + auto scaleAttr = + builder.getFloatAttr(expressedType, quantizedType.getScale()); + auto scale = builder.create(loc, expressedType, scaleAttr); + auto zeroPointAttr = + builder.getIntegerAttr(storageType, quantizedType.getZeroPoint()); + auto zeroPoint = + builder.create(loc, storageType, zeroPointAttr); + + auto inputShape = getScalarOrTensorShape(builder, loc, input); + return convertRanked(builder, loc, op, input, inputShape, scale, zeroPoint, + quantizedType); +} + +// Convert an operation using per-layer quantization. +// +// - op +// 'quant.dcast' or 'quant.qcast' op. +// +// - input +// Scalar, ranked tensor, or unranked tensor. +// +// - quantizedType +// Per-layer quantized type. +// +Value convertPerLayer(OpBuilder &builder, Location loc, Operation *op, + Value input, UniformQuantizedType quantizedType) { + // Flatten input if unranked + bool isUnranked = isa(input.getType()); + Value inputShape; + if (isUnranked) + std::tie(input, inputShape) = flattenUnrankedTensor(builder, loc, input); + + // Process ranked tensor + auto result = convertPerLayerRanked(builder, loc, op, input, quantizedType); + + // Restore original shape if unranked + if (isUnranked) + result = restoreUnrankedTensorShape(builder, loc, result, inputShape); + + return result; +} + +// Convert an operation using per-channel quantization and a scalar or ranked +// tensor as an input. +// +// - op +// 'quant.dcast' or 'quant.qcast' op. +// +// - input +// Scalar or ranked tensor. +// +// - quantizedType +// Per-channel quantized type. +// +Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op, + Value input, + UniformQuantizedPerAxisType quantizedType, + int64_t channelAxis) { + auto *context = builder.getContext(); + + auto inputType = cast(input.getType()); + auto inputRank = inputType.getRank(); + + auto scales = materializePerChannelScales(builder, loc, quantizedType); + auto zeroPoints = + materializePerChannelZeroPoints(builder, loc, quantizedType); + + auto elementType = isa(inputType.getElementType()) + ? quantizedType.getStorageType() + : quantizedType.getExpressedType(); + auto initShape = tensor::getMixedSizes(builder, loc, input); + Value init = builder.create(loc, initShape, elementType); + + SmallVector iteratorTypes( + inputRank, utils::IteratorType::parallel); + auto channelAxisAffineMap = AffineMap::get( + inputRank, 0, builder.getAffineDimExpr(channelAxis), context); + SmallVector indexingMaps{ + builder.getMultiDimIdentityMap(inputRank), + channelAxisAffineMap, + channelAxisAffineMap, + builder.getMultiDimIdentityMap(inputRank) + }; + auto result = builder.create( + loc, + init.getType(), // resultType + ValueRange{input, scales, zeroPoints}, // inputs + ValueRange{init}, // outputs + indexingMaps, + iteratorTypes, + [&](OpBuilder& builder, Location loc, ValueRange args) { + assert(args.size() == 4); + auto input = args[0]; + auto scale = args[1]; + auto zeroPoint = args[2]; + + auto result = convertRanked(builder, loc, op, input, {}, scale, + zeroPoint, quantizedType); + + builder.create(loc, result); + }) + .getResult(0); + + return result; +} + +// Convert an operation using per-channel quantization. +// +// - op +// 'quant.dcast' or 'quant.qcast' op. +// +// - input +// Scalar, ranked tensor, or unranked tensor. +// +// - quantizedType +// Per-channel quantized type. +// +Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op, + Value input, + UniformQuantizedPerAxisType quantizedType) { + // Flatten unranked tensor into a 3D ranked tensor if necessary + bool isUnranked = isa(input.getType()); + int64_t channelAxis = quantizedType.getQuantizedDimension(); + int64_t channelAxisSize = (int64_t) quantizedType.getScales().size(); + Value inputShape; + if (isUnranked) { + std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis( + builder, loc, input, channelAxis, channelAxisSize); + channelAxis = 1; + } + + // Work on a ranked tensor + auto result = convertPerChannelRanked(builder, loc, op, input, quantizedType, + channelAxis); + + // Restore original tensor shape if unranked + if (isUnranked) + result = restoreUnrankedTensorShape(builder, loc, result, inputShape); + + return result; +} + +// Convert a quantization operation. +// +// - op +// 'quant.dcast' or 'quant.qcast' op. +// +// - input +// Scalar, ranked tensor, or unranked tensor. The element type matches +// the storage type (quant.dcast) or expressed type (quant.qcast) of +// 'quantizedType'. +// +// - quantizedType +// Per-layer or per-channel quantized type. +// +Value convertQuantized(OpBuilder &builder, Location loc, Operation *op, + Value input, Type quantizedType) { + if (auto uniformQuantizedType = dyn_cast(quantizedType)) + return convertPerLayer(builder, loc, op, input, uniformQuantizedType); + + if (auto uniformQuantizedPerAxisType = + dyn_cast(quantizedType)) + return convertPerChannel(builder, loc, op, input, + uniformQuantizedPerAxisType); + + llvm_unreachable("unexpected quantized type"); +} + +// Lowering pattern for 'quant.dcast' +struct DequantizeCastOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(quant::DequantizeCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto input = op.getInput(); + auto quantizedType = + cast(getScalarType(op.getInput().getType())); + + // Convert quantized input to storage type + auto storageScalarOrTensorType = + getScalarOrTensorType(quantizedType.getStorageType(), input.getType()); + input = rewriter.create( + loc, storageScalarOrTensorType, input); + + auto result = convertQuantized(rewriter, loc, op, input, quantizedType); + + rewriter.replaceOp(op, result); + return success(); + } +}; + +// Lowering pattern for 'quant.qcast' +struct QuantizeCastOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(quant::QuantizeCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto input = op.getInput(); + auto quantizedType = getScalarType(op.getResult().getType()); + + // Flatten unranked tensor input + auto result = convertQuantized(rewriter, loc, op, input, quantizedType); + + // Cast stored value to result quantized value + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), result); + return success(); + } +}; + +struct LowerQuantOps : public impl::LowerQuantOpsBase { + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + populateLowerQuantOpsPatterns(patterns); + + ConversionTarget target(getContext()); + target.addLegalOp(); + target.addIllegalDialect(); + target.addLegalDialect< + arith::ArithDialect, + linalg::LinalgDialect, + shape::ShapeDialect, + tensor::TensorDialect + >(); + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace + +void populateLowerQuantOpsPatterns(RewritePatternSet &patterns) { + patterns.add< + DequantizeCastOpConversion, + QuantizeCastOpConversion + >(patterns.getContext()); +} + +} // namespace quant +} // namespace mlir diff --git a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp new file mode 100644 index 0000000000000..8996eff61a39c --- /dev/null +++ b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp @@ -0,0 +1,114 @@ +//===- StripFuncQuantTypes.cpp - Strip quantized types --------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Strips quantized types from function headers. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Quant/IR/Quant.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" +#include "mlir/Dialect/Quant/Transforms/Passes.h" +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace quant { + +#define GEN_PASS_DEF_STRIPFUNCQUANTTYPES +#include "mlir/Dialect/Quant/Transforms/Passes.h.inc" + +namespace { + +class QuantizedTypeConverter : public TypeConverter { + + static Type convertQuantizedType(QuantizedType quantizedType) { + return quantizedType.getStorageType(); + } + + static Type convertTensorType(TensorType tensorType) { + if (auto quantizedType = dyn_cast(tensorType.getElementType())) + return tensorType.clone(convertQuantizedType(quantizedType)); + return tensorType; + } + + static Value materializeConversion(OpBuilder &builder, Type type, + ValueRange inputs, Location loc) { + assert(inputs.size() == 1); + return builder.create(loc, type, inputs[0]); + } + +public: + + explicit QuantizedTypeConverter() { + addConversion([](Type type) { return type; }); + addConversion(convertQuantizedType); + addConversion(convertTensorType); + + addArgumentMaterialization(materializeConversion); + addSourceMaterialization(materializeConversion); + addTargetMaterialization(materializeConversion); + } +}; + +// Conversion pass +class StripFuncQuantTypes : public impl::StripFuncQuantTypesBase { + + // Return whether a type is considered legal when occurring in the header of + // a function or as an operand to a 'return' op. + static bool isLegalType(Type type) { + if (auto tensorType = dyn_cast(type)) + return isLegalType(tensorType.getElementType()); + return !isa(type); + } + +public: + + void runOnOperation() override { + + auto moduleOp = cast(getOperation()); + auto* context = &getContext(); + + QuantizedTypeConverter typeConverter; + ConversionTarget target(*context); + RewritePatternSet patterns(context); + + // Mark func.func, func.return, and func.call illegal if they contain any + // quantized types. + target.addDynamicallyLegalOp([&](func::FuncOp op) { + return typeConverter.isSignatureLegal(op.getFunctionType()) && + typeConverter.isLegal(&op.getBody()); + }); + target.addDynamicallyLegalOp( + [&](func::ReturnOp op) { return typeConverter.isLegal(op); }); + target.addDynamicallyLegalOp( + [&](func::CallOp op) { return typeConverter.isLegal(op); }); + + // Register conversion patterns + populateFunctionOpInterfaceTypeConversionPattern( + patterns, typeConverter); + populateReturnOpTypeConversionPattern(patterns, typeConverter); + populateCallOpTypeConversionPattern(patterns, typeConverter); + + // Apply conversion + if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace + +} // namespace quant +} // namespace mlir + diff --git a/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp b/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp index 8c69729824691..fb27640bfd278 100644 --- a/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp +++ b/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Quant/FakeQuantSupport.h" -#include "mlir/Dialect/Quant/QuantTypes.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" +#include "mlir/Dialect/Quant/Utils/FakeQuantSupport.h" using namespace mlir; using namespace mlir::quant; diff --git a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp index 408701f80444a..62c7a7128d63a 100644 --- a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp +++ b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Quant/UniformSupport.h" +#include "mlir/Dialect/Quant/Utils/UniformSupport.h" #include "mlir/IR/BuiltinTypes.h" #include diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 03876a7c64d07..c62942e1be78e 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -11,7 +11,7 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Quant/QuantOps.h" +#include "mlir/Dialect/Quant/IR/Quant.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 6dce3d03066c9..7f740be4efb4f 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -14,7 +14,7 @@ #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" -#include "mlir/Dialect/Quant/QuantOps.h" +#include "mlir/Dialect/Quant/IR/Quant.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" diff --git a/mlir/test/Dialect/Quant/canonicalize.mlir b/mlir/test/Dialect/Quant/canonicalize.mlir index 36c3eaf5e10d2..73c57e2a48212 100644 --- a/mlir/test/Dialect/Quant/canonicalize.mlir +++ b/mlir/test/Dialect/Quant/canonicalize.mlir @@ -1,24 +1,124 @@ // RUN: mlir-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize{test-convergence}))' | FileCheck %s +// CHECK-LABEL: @dcast_fold +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK: return %[[ARG_0]] + +!qalias = !quant.uniform +func.func @dcast_fold(%arg0: tensor<4xf32>) -> tensor<4xf32> { + %0 = quant.qcast %arg0 : tensor<4xf32> to tensor<4x!qalias> + %1 = quant.dcast %0 : tensor<4x!qalias> to tensor<4xf32> + return %1 : tensor<4xf32> +} + // ----- -// CHECK-LABEL: redundant_scast -func.func @redundant_scast() -> tensor<4xi8> { - // CHECK-NEXT: arith.constant dense<10> : tensor<4xi8> - // CHECK-NEXT: return - %cst = arith.constant dense<5> : tensor<4xi8> - %1 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant.uniform> - %2 = "quant.scast"(%1) : (tensor<4x!quant.uniform>) -> tensor<4xi8> - %3 = arith.addi %2, %2 : tensor<4xi8> - return %3 : tensor<4xi8> + +// CHECK-LABEL: @dcast_no_fold_source +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK: %[[VAL_0:.*]] = quant.scast %[[ARG_0]] +// CHECK: %[[VAL_1:.*]] = quant.dcast %[[VAL_0]] +// CHECK: return %[[VAL_1]] + +!qalias = !quant.uniform +func.func @dcast_no_fold_source(%arg0: tensor<4xi8>) -> tensor<4xf32> { + %0 = quant.scast %arg0 : tensor<4xi8> to tensor<4x!qalias> + %1 = quant.dcast %0 : tensor<4x!qalias> to tensor<4xf32> + return %1 : tensor<4xf32> } // ----- -// CHECK-LABEL: non_redundant_scast -func.func @non_redundant_scast() -> tensor<4x!quant.uniform> { - // CHECK-NEXT: arith.constant dense<5> : tensor<4xi8> - // CHECK-NEXT: scast - // CHECK-NEXT: return - %cst = arith.constant dense<5> : tensor<4xi8> - %1 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant.uniform> - return %1 : tensor<4x!quant.uniform> + +// CHECK-LABEL: @qcast_fold +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK: return %[[ARG_0]] + +!qalias = !quant.uniform +func.func @qcast_fold(%arg0: tensor<4x!qalias>) -> tensor<4x!qalias> { + %0 = quant.dcast %arg0 : tensor<4x!qalias> to tensor<4xf32> + %1 = quant.qcast %0 : tensor<4xf32> to tensor<4x!qalias> + return %1 : tensor<4x!qalias> } + +// ----- + +// CHECK-LABEL: @qcast_no_fold_source +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK: %[[VAL_0:.*]] = arith.negf %[[ARG_0]] +// CHECK: %[[VAL_1:.*]] = quant.qcast %[[VAL_0]] +// CHECK: return %[[VAL_1]] + +!qalias = !quant.uniform +func.func @qcast_no_fold_source(%arg0: tensor<4xf32>) -> tensor<4x!qalias> { + %0 = arith.negf %arg0 : tensor<4xf32> + %1 = quant.qcast %0 : tensor<4xf32> to tensor<4x!qalias> + return %1 : tensor<4x!qalias> +} + +// ----- + +// CHECK-LABEL: @qcast_no_fold_type +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK: %[[VAL_0:.*]] = quant.dcast %[[ARG_0]] +// CHECK: %[[VAL_1:.*]] = quant.qcast %[[VAL_0]] +// CHECK: return %[[VAL_1]] + +!qalias = !quant.uniform +!qalias1 = !quant.uniform +func.func @qcast_no_fold_type(%arg0: tensor<4x!qalias>) -> tensor<4x!qalias1> { + %0 = quant.dcast %arg0 : tensor<4x!qalias> to tensor<4xf32> + %1 = quant.qcast %0 : tensor<4xf32> to tensor<4x!qalias1> + return %1 : tensor<4x!qalias1> +} + +// ----- + +// CHECK-LABEL: @scast_fold +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK: return %[[ARG_0]] + +!qalias = !quant.uniform +func.func @scast_fold(%arg0: tensor<4x!qalias>) -> tensor<4x!qalias> { + %0 = quant.scast %arg0 : tensor<4x!qalias> to tensor<4xi8> + %1 = quant.scast %0 : tensor<4xi8> to tensor<4x!qalias> + return %1 : tensor<4x!qalias> +} + +// ----- + +// CHECK-LABEL: @scast_no_fold_source +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK: %[[QCAST:.*]] = quant.qcast %[[ARG_0]] +// CHECK: %[[SCAST:.*]] = quant.scast %[[QCAST]] +// CHECK: return %[[SCAST]] + +!qalias = !quant.uniform +func.func @scast_no_fold_source(%arg0: tensor<4xf32>) -> tensor<4xi8> { + %0 = quant.qcast %arg0 : tensor<4xf32> to tensor<4x!qalias> + %1 = quant.scast %0 : tensor<4x!qalias> to tensor<4xi8> + return %1 : tensor<4xi8> +} + +// ----- + +// CHECK-LABEL: @scast_no_fold_type +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK: %[[VAL_0:.*]] = quant.scast %[[ARG_0]] +// CHECK: %[[VAL_1:.*]] = quant.scast %[[VAL_0]] +// CHECK: return %[[VAL_1]] + +!qalias = !quant.uniform +!qalias1 = !quant.uniform +func.func @scast_no_fold_type(%arg0: tensor<4x!qalias>) -> tensor<4x!qalias1> { + %0 = quant.scast %arg0 : tensor<4x!qalias> to tensor<4xi8> + %1 = quant.scast %0 : tensor<4xi8> to tensor<4x!qalias1> + return %1 : tensor<4x!qalias1> +} + diff --git a/mlir/test/Dialect/Quant/invalid.mlir b/mlir/test/Dialect/Quant/invalid.mlir new file mode 100644 index 0000000000000..ba3a8e312d96e --- /dev/null +++ b/mlir/test/Dialect/Quant/invalid.mlir @@ -0,0 +1,258 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics + +func.func @dcast_invalid_input(%arg0: f32) { + // expected-error@+1 {{operand #0 must be scalar or tensor of quantized type}} + %0 = quant.dcast %arg0 : f32 to f32 + return +} + +// ----- + +!qalias = !quant.uniform +func.func @dcast_invalid_result(%arg0: !qalias) { + // expected-error@+1 {{result #0 must be scalar or tensor of floating-point}} + %0 = quant.dcast %arg0 : !qalias to !qalias + return +} + +// ----- + +!qalias = !quant.uniform +func.func @dcast_mismatch_scalar_tensor(%arg0: !qalias) { + // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}} + %0 = quant.dcast %arg0 : !qalias to tensor + return +} + +// ----- + +!qalias = !quant.uniform +func.func @dcast_mismatch_ranked_unranked_tensor(%arg0: tensor) { + // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}} + %0 = quant.dcast %arg0 : tensor to tensor<*xf32> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @dcast_mismatch_static_dynamic_tensor(%arg0: tensor<2x3x!qalias>) { + // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}} + %0 = quant.dcast %arg0 : tensor<2x3x!qalias> to tensor + return +} + +// ----- + +!qalias = !quant.uniform +func.func @dcast_float_type_mismatch(%arg0: !qalias) { + // expected-error@+1 {{expressed type in quantized type expected to match float type}} + %0 = quant.dcast %arg0 : !qalias to f64 + return +} + +// ----- + +!qalias = !quant.uniform +func.func @dcast_per_axis_scalar(%arg0: !qalias) { + // expected-error@+1 {{scalar types may not use per-axis quantization}} + %0 = quant.dcast %arg0 : !qalias to f32 + return +} + +// ----- + +!qalias = !quant.uniform +func.func @dcast_per_axis_invalid_rank(%arg0: tensor<2x3x!qalias>) { + // expected-error@+1 {{quantized dimension must be less than tensor rank}} + %0 = quant.dcast %arg0 : tensor<2x3x!qalias> to tensor<2x3xf32> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @dcast_per_axis_invalid_rank(%arg0: tensor<2x3x4x!qalias>) { + // expected-error@+1 {{quantized dimension size does not match number of scales}} + %0 = quant.dcast %arg0 : tensor<2x3x4x!qalias> to tensor<2x3x4xf32> + return +} + +// ----- + +func.func @qcast_invalid_input(%arg0: f32) { + // expected-error@+1 {{result #0 must be scalar or tensor of quantized type}} + %0 = quant.qcast %arg0 : f32 to f32 + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_invalid_result(%arg0: !qalias) { + // expected-error@+1 {{operand #0 must be scalar or tensor of floating-point}} + %0 = quant.qcast %arg0 : !qalias to !qalias + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_mismatch_scalar_tensor(%arg0: tensor) { + // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}} + %0 = quant.qcast %arg0 : tensor to !qalias + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_mismatch_ranked_unranked_tensor(%arg0: tensor) { + // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}} + %0 = quant.qcast %arg0 : tensor to tensor<*x!qalias> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_mismatch_static_dynamic_tensor(%arg0: tensor<2x3xf32>) { + // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}} + %0 = quant.qcast %arg0 : tensor<2x3xf32> to tensor + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_float_type_mismatch(%arg0: f64) { + // expected-error@+1 {{expressed type in quantized type expected to match float type}} + %0 = quant.qcast %arg0 : f64 to !qalias + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_per_axis_scalar(%arg0: f32) { + // expected-error@+1 {{scalar types may not use per-axis quantization}} + %0 = quant.qcast %arg0 : f32 to !qalias + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_per_axis_invalid_rank(%arg0: tensor<2x3xf32>) { + // expected-error@+1 {{quantized dimension must be less than tensor rank}} + %0 = quant.qcast %arg0 : tensor<2x3xf32> to tensor<2x3x!qalias> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_per_axis_invalid_rank(%arg0: tensor<2x3x4xf32>) { + // expected-error@+1 {{quantized dimension size does not match number of scales}} + %0 = quant.qcast %arg0 : tensor<2x3x4xf32> to tensor<2x3x4x!qalias> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_invalid_input(%arg0: si32) { + // expected-error@+1 {{operand #0 must be scalar or tensor of signless integer or quantized type}} + %0 = quant.scast %arg0 : si32 to !qalias + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_invalid_result(%arg0: !qalias) { + // expected-error@+1 {{result #0 must be scalar or tensor of signless integer or quantized type}} + %0 = quant.scast %arg0 : !qalias to si32 + return +} + +// ----- + +func.func @scast_both_integers(%arg0: i8) { + // expected-error@+1 {{input must be integer and result must be quantized, or vice versa}} + %0 = quant.scast %arg0 : i8 to i8 + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_both_quantized(%arg0: !qalias) { + // expected-error@+1 {{input must be integer and result must be quantized, or vice versa}} + %0 = quant.scast %arg0 : !qalias to !qalias + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_mismatch_scalar_tensor(%arg0: tensor) { + // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}} + %0 = quant.scast %arg0 : tensor to !qalias + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_mismatch_ranked_unranked_tensor(%arg0: tensor) { + // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}} + %0 = quant.scast %arg0 : tensor to tensor<*x!qalias> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_mismatch_static_dynamic_tensor(%arg0: tensor<2x3xi8>) { + // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}} + %0 = quant.scast %arg0 : tensor<2x3xi8> to tensor + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_integer_type_mismatch(%arg0: i32) { + // expected-error@+1 {{storage type in quantized type expected to match integer type}} + %0 = quant.scast %arg0 : i32 to !qalias + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_per_axis_scalar(%arg0: i8) { + // expected-error@+1 {{scalar types may not use per-axis quantization}} + %0 = quant.scast %arg0 : i8 to !qalias + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_per_axis_invalid_rank(%arg0: tensor<2x3xi8>) { + // expected-error@+1 {{quantized dimension must be less than tensor rank}} + %0 = quant.scast %arg0 : tensor<2x3xi8> to tensor<2x3x!qalias> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_per_axis_invalid_rank(%arg0: tensor<2x3x4xi8>) { + // expected-error@+1 {{quantized dimension size does not match number of scales}} + %0 = quant.scast %arg0 : tensor<2x3x4xi8> to tensor<2x3x4x!qalias> + return +} + diff --git a/mlir/test/Dialect/Quant/lower-quant-ops.mlir b/mlir/test/Dialect/Quant/lower-quant-ops.mlir new file mode 100644 index 0000000000000..6bba9f5c03772 --- /dev/null +++ b/mlir/test/Dialect/Quant/lower-quant-ops.mlir @@ -0,0 +1,511 @@ +// RUN: mlir-opt %s --lower-quant-ops --split-input-file | FileCheck %s + +// CHECK-LABEL: @dcast_per_layer_scalar +// CHECK-SAME: %[[ARG_0:.*]]: !quant.uniform + +// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : !quant.uniform to i8 + +// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8 +// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : i8 to f32 +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32 + +// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32 +// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32 +// CHECK: return %[[EXPRESSED]] : f32 + +!qalias = !quant.uniform +func.func @dcast_per_layer_scalar(%arg0: !qalias) -> f32 { + %0 = quant.dcast %arg0 : !qalias to f32 + return %0 : f32 +} + +// ----- + +// CHECK-LABEL: @dcast_per_layer_scalar_unsigned +// CHECK-SAME: %[[ARG_0:.*]]: !quant.uniform + +// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : !quant.uniform to i8 + +// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8 + +// CHECK: %[[STORED_FLOAT:.*]] = arith.uitofp %[[STORED_INT]] : i8 to f32 +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.uitofp %[[ZERO_POINT]] : i8 to f32 + +// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32 +// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32 +// CHECK: return %[[EXPRESSED]] : f32 + +!qalias = !quant.uniform +func.func @dcast_per_layer_scalar_unsigned(%arg0: !qalias) -> f32 { + %0 = quant.dcast %arg0 : !qalias to f32 + return %0 : f32 +} + +// ----- + +// CHECK-LABEL: @dcast_per_layer_0d +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : tensor> to tensor + +// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8 +// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]] : tensor +// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : tensor to tensor +// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]] : tensor +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor to tensor + +// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : tensor +// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE_TENSOR]] : tensor +// CHECK: return %[[EXPRESSED]] : tensor + +!qalias = !quant.uniform +func.func @dcast_per_layer_0d(%arg0: tensor) -> tensor { + %0 = quant.dcast %arg0 : tensor to tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @dcast_per_layer_ranked +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : tensor<3x?x5x!quant.uniform> to tensor<3x?x5xi8> +// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8 +// CHECK: %[[C_1:.*]] = arith.constant 1 : index +// CHECK: %[[DIM_1:.*]] = tensor.dim %[[STORED_INT]], %[[C_1]] : tensor<3x?x5xi8> +// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]]{{\[}}%[[DIM_1]]] : tensor<3x?x5xf32> +// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : tensor<3x?x5xi8> to tensor<3x?x5xf32> +// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]]{{\[}}%[[DIM_1]]] : tensor<3x?x5xi8> +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor<3x?x5xi8> to tensor<3x?x5xf32> + +// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : tensor<3x?x5xf32> +// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE_TENSOR]] : tensor<3x?x5xf32> +// CHECK: return %[[EXPRESSED]] : tensor<3x?x5xf32> + +!qalias = !quant.uniform +func.func @dcast_per_layer_ranked(%arg0: tensor<3x?x5x!qalias>) -> tensor<3x?x5xf32> { + %0 = quant.dcast %arg0 : tensor<3x?x5x!qalias> to tensor<3x?x5xf32> + return %0 : tensor<3x?x5xf32> +} + +// ----- + +// CHECK-LABEL: @dcast_per_layer_unranked +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : tensor<*x!quant.uniform> to tensor<*xi8> +// CHECK: %[[INPUT_SHAPE:.*]] = shape.shape_of %[[STORED_INT]] : tensor<*xi8> -> tensor +// CHECK: %[[INPUT_SIZE:.*]] = shape.num_elements %[[INPUT_SHAPE]] : tensor -> index +// CHECK: %[[COLLAPSED_SHAPE:.*]] = tensor.from_elements %[[INPUT_SIZE]] : tensor<1xindex> +// CHECK: %[[STORED_COLLAPSED:.*]] = tensor.reshape %[[STORED_INT]](%[[COLLAPSED_SHAPE]]) : (tensor<*xi8>, tensor<1xindex>) -> tensor +// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8 +// CHECK: %[[C_0:.*]] = arith.constant 0 : index +// CHECK: %[[DIM_0:.*]] = tensor.dim %[[STORED_COLLAPSED]], %[[C_0]] : tensor +// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]]{{\[}}%[[DIM_0]]] : tensor +// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_COLLAPSED]] : tensor to tensor +// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]]{{\[}}%[[DIM_0]]] : tensor +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor to tensor + +// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : tensor +// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE_TENSOR]] : tensor + +// CHECK: %[[EXPRESSED_EXPANDED:.*]] = tensor.reshape %[[EXPRESSED]](%[[INPUT_SHAPE]]) : (tensor, tensor) -> tensor<*xf32> +// CHECK: return %[[EXPRESSED_EXPANDED]] : tensor<*xf32> + +!qalias = !quant.uniform +func.func @dcast_per_layer_unranked(%arg0: tensor<*x!qalias>) -> tensor<*xf32> { + %0 = quant.dcast %arg0 : tensor<*x!qalias> to tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1)> + +// CHECK-LABEL: @dcast_per_channel_ranked +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK: %[[STORED_TENSOR:.*]] = quant.scast %[[ARG_0]] : tensor<4x?x?x5x!quant.uniform> to tensor<4x?x?x5xi8> + +// CHECK: %[[SCALES:.*]] = arith.constant dense<[2.000000e+00, 3.000000e+00]> : tensor<2xf32> +// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<[10, 20]> : tensor<2xi8> +// CHECK: %[[C_1:.*]] = arith.constant 1 : index +// CHECK: %[[DIM_1:.*]] = tensor.dim %[[STORED_TENSOR]], %[[C_1]] : tensor<4x?x?x5xi8> +// CHECK: %[[C_2:.*]] = arith.constant 2 : index +// CHECK: %[[DIM_2:.*]] = tensor.dim %[[STORED_TENSOR]], %[[C_2]] : tensor<4x?x?x5xi8> +// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM_1]], %[[DIM_2]]) : tensor<4x?x?x5xf32> +// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[STORED_TENSOR]], %[[SCALES]], %[[ZERO_POINTS]] : tensor<4x?x?x5xi8>, tensor<2xf32>, tensor<2xi8>) outs(%[[INIT]] : tensor<4x?x?x5xf32>) { +// CHECK: ^bb0(%[[STORED_INT:.*]]: i8, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: f32): +// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : i8 to f32 +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32 +// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32 +// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32 +// CHECK: linalg.yield %[[EXPRESSED]] : f32 +// CHECK: } -> tensor<4x?x?x5xf32> +// CHECK: return %[[GENERIC]] : tensor<4x?x?x5xf32> + +!qalias = !quant.uniform +func.func @dcast_per_channel_ranked(%arg0: tensor<4x?x?x5x!qalias>) -> tensor<4x?x?x5xf32> { + %0 = quant.dcast %arg0 : tensor<4x?x?x5x!qalias> to tensor<4x?x?x5xf32> + return %0 : tensor<4x?x?x5xf32> +} + +// ----- + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1)> + +// CHECK-LABEL: @dcast_per_channel_unranked +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK: %[[STORED_TENSOR:.*]] = quant.scast %[[ARG_0]] : tensor<*x!quant.uniform> to tensor<*xi8> +// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[STORED_TENSOR]] : tensor<*xi8> -> tensor +// CHECK: %[[CHANNEL_AXIS:.*]] = arith.constant 2 : index +// CHECK: %[[CHANNEL_AXIS_NEXT:.*]] = arith.constant 3 : index +// CHECK: %[[SHAPE_LEFT:.*]], %[[DISCARDED_0:.*]] = "shape.split_at"(%[[SHAPE]], %[[CHANNEL_AXIS]]) : (tensor, index) -> (tensor, tensor) +// CHECK: %[[SIZE_LEFT:.*]] = shape.num_elements %[[SHAPE_LEFT]] : tensor -> index +// CHECK: %[[DISCARDED_1:.*]], %[[SHAPE_RIGHT:.*]] = "shape.split_at"(%[[SHAPE]], %[[CHANNEL_AXIS_NEXT]]) : (tensor, index) -> (tensor, tensor) +// CHECK: %[[SIZE_RIGHT:.*]] = shape.num_elements %[[SHAPE_RIGHT]] : tensor -> index + +// CHECK: %[[NUM_CHANNELS:.*]] = arith.constant 3 : index +// CHECK: %[[COLLAPSED_SHAPE:.*]] = tensor.from_elements %[[SIZE_LEFT]], %[[NUM_CHANNELS]], %[[SIZE_RIGHT]] : tensor<3xindex> +// CHECK: %[[STORED_COLLAPSED:.*]] = tensor.reshape %[[STORED_TENSOR]](%[[COLLAPSED_SHAPE]]) : (tensor<*xi8>, tensor<3xindex>) -> tensor + +// CHECK: %[[SCALES:.*]] = arith.constant dense<[2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<3xf32> +// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<[10, 20, 30]> : tensor<3xi8> +// CHECK: %[[C_0:.*]] = arith.constant 0 : index +// CHECK: %[[DIM_0:.*]] = tensor.dim %[[STORED_COLLAPSED]], %[[C_0]] : tensor +// CHECK: %[[C_2:.*]] = arith.constant 2 : index +// CHECK: %[[DIM_2:.*]] = tensor.dim %[[STORED_COLLAPSED]], %[[C_2]] : tensor +// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM_0]], %[[DIM_2]]) : tensor +// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[STORED_COLLAPSED]], %[[SCALES]], %[[ZERO_POINTS]] : tensor, tensor<3xf32>, tensor<3xi8>) outs(%[[INIT]] : tensor) { +// CHECK: ^bb0(%[[STORED_INT:.*]]: i8, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: f32): +// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : i8 to f32 +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32 +// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32 +// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32 +// CHECK: linalg.yield %[[EXPRESSED]] : f32 +// CHECK: } -> tensor + +// CHECK: %[[EXPRESSED_EXPANDED:.*]] = tensor.reshape %[[GENERIC]](%[[SHAPE]]) : (tensor, tensor) -> tensor<*xf32> +// CHECK: return %[[EXPRESSED_EXPANDED]] : tensor<*xf32> + +!qalias = !quant.uniform +func.func @dcast_per_channel_unranked(%arg0: tensor<*x!qalias>) -> tensor<*xf32> { + %0 = quant.dcast %arg0 : tensor<*x!qalias> to tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +// CHECK-LABEL: @qcast_per_layer_scalar +// CHECK-SAME: %[[ARG_0:.*]]: f32 + +// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8 + +// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE]] : f32 +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32 +// CHECK: %[[STORED:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : f32 +// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED]] : f32 to i8 + +// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_INT]] : i8 to !quant.uniform +// CHECK: return %[[STORED_QUANT]] : !quant.uniform + +!qalias = !quant.uniform +func.func @qcast_per_layer_scalar(%arg0: f32) -> !qalias { + %0 = quant.qcast %arg0 : f32 to !qalias + return %0 : !qalias +} + +// ----- + +// CHECK-LABEL: @qcast_per_layer_scalar_bounds +// CHECK-SAME: %[[ARG_0:.*]]: f32 + +// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 0 : i8 + +// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE]] : f32 +// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[SCALED]] : f32 to i8 + +// CHECK-DAG: %[[C_NEG_5:.*]] = arith.constant -5 : i8 +// CHECK-DAG: %[[C_10:.*]] = arith.constant 10 : i8 +// CHECK: %[[STORED_CLAMPED_TEMP:.*]] = arith.maxsi %[[STORED_INT]], %[[C_NEG_5]] : i8 +// CHECK: %[[STORED_CLAMPED:.*]] = arith.minsi %[[STORED_CLAMPED_TEMP]], %[[C_10]] : i8 + +// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_CLAMPED]] : i8 to !quant.uniform:f32, 2.000000e+00> +// CHECK: return %[[STORED_QUANT]] : !quant.uniform:f32, 2.000000e+00> + +!qalias = !quant.uniform:f32, 2.0> +func.func @qcast_per_layer_scalar_bounds(%arg0: f32) -> !qalias { + %0 = quant.qcast %arg0 : f32 to !qalias + return %0 : !qalias +} + +// ----- + +// CHECK-LABEL: @qcast_per_layer_scalar_unsigned_bounds +// CHECK-SAME: %[[ARG_0:.*]]: f32 + +// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 0 : i8 + +// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE]] : f32 +// CHECK: %[[STORED_INT:.*]] = arith.fptoui %[[SCALED]] : f32 to i8 + +// CHECK-DAG: %[[C_2:.*]] = arith.constant 2 : i8 +// CHECK-DAG: %[[C_10:.*]] = arith.constant 10 : i8 +// CHECK: %[[STORED_CLAMPED_TEMP:.*]] = arith.maxui %[[STORED_INT]], %[[C_2]] : i8 +// CHECK: %[[STORED_CLAMPED:.*]] = arith.minui %[[STORED_CLAMPED_TEMP]], %[[C_10]] : i8 + +// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_CLAMPED]] : i8 to !quant.uniform:f32, 2.000000e+00> +// CHECK: return %[[STORED_QUANT]] : !quant.uniform:f32, 2.000000e+00> + +!qalias = !quant.uniform:f32, 2.0> +func.func @qcast_per_layer_scalar_unsigned_bounds(%arg0: f32) -> !qalias { + %0 = quant.qcast %arg0 : f32 to !qalias + return %0 : !qalias +} + +// ----- + +// CHECK-LABEL: @qcast_per_layer_0d +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8 + +// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]] : tensor +// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE_TENSOR]] : tensor + +// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]] : tensor +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor to tensor +// CHECK: %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : tensor +// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : tensor to tensor + +// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_INT]] : tensor to tensor> +// CHECK: return %[[STORED_QUANT]] : tensor> + +!qalias = !quant.uniform +func.func @qcast_per_layer_0d(%arg0: tensor) -> tensor { + %0 = quant.qcast %arg0 : tensor to tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @qcast_per_layer_ranked +// CHECK-SAME: %[[ARG_0:.*]]: tensor<3x?x5xf32> + +// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8 +// CHECK-DAG: %[[C_1:.*]] = arith.constant 1 : index + +// CHECK: %[[DIM_1:.*]] = tensor.dim %[[ARG_0]], %[[C_1]] : tensor<3x?x5xf32> +// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]]{{\[}}%[[DIM_1]]] : tensor<3x?x5xf32> +// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE_TENSOR]] : tensor<3x?x5xf32> + +// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]]{{\[}}%[[DIM_1]]] : tensor<3x?x5xi8> +// CHECK: %[[ZERO_POINT_TENSOR_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor<3x?x5xi8> to tensor<3x?x5xf32> +// CHECK: %[[STORED:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_TENSOR_FLOAT]] : tensor<3x?x5xf32> +// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED]] : tensor<3x?x5xf32> to tensor<3x?x5xi8> + +// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_INT]] : tensor<3x?x5xi8> to tensor<3x?x5x!quant.uniform> +// CHECK: return %[[STORED_QUANT]] : tensor<3x?x5x!quant.uniform> + +!qalias = !quant.uniform +func.func @qcast_per_layer_ranked(%arg0: tensor<3x?x5xf32>) -> tensor<3x?x5x!qalias> { + %0 = quant.qcast %arg0 : tensor<3x?x5xf32> to tensor<3x?x5x!qalias> + return %0 : tensor<3x?x5x!qalias> +} + +// ----- + +// CHECK-LABEL: @qcast_per_layer_ranked_bounds +// CHECK-SAME: %[[ARG_0:.*]]: tensor<3x5xf32> + +// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8 + +// CHECK: %[[SCALE_SPLAT:.*]] = tensor.splat %[[SCALE]] : tensor<3x5xf32> +// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE_SPLAT]] : tensor<3x5xf32> + +// CHECK: %[[ZERO_POINT_SPLAT:.*]] = tensor.splat %[[ZERO_POINT]] : tensor<3x5xi8> +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_SPLAT]] : tensor<3x5xi8> to tensor<3x5xf32> + +// CHECK: %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : tensor<3x5xf32> +// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : tensor<3x5xf32> to tensor<3x5xi8> + +// CHECK-DAG: %[[C_NEG_8:.*]] = arith.constant -8 : i8 +// CHECK-DAG: %[[C_7:.*]] = arith.constant 7 : i8 +// CHECK-DAG: %[[SPLAT_NEG_8:.*]] = tensor.splat %[[C_NEG_8]] : tensor<3x5xi8> +// CHECK-DAG: %[[SPLAT_7:.*]] = tensor.splat %[[C_7]] : tensor<3x5xi8> +// CHECK: %[[STORED_CLAMPED_TEMP:.*]] = arith.maxsi %[[STORED_INT]], %[[SPLAT_NEG_8]] : tensor<3x5xi8> +// CHECK: %[[STORED_CLAMPED:.*]] = arith.minsi %[[STORED_CLAMPED_TEMP]], %[[SPLAT_7]] : tensor<3x5xi8> + +// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_CLAMPED]] : tensor<3x5xi8> to tensor<3x5x!quant.uniform:f32, 2.000000e+00:10>> +// CHECK: return %[[STORED_QUANT]] : tensor<3x5x!quant.uniform:f32, 2.000000e+00:10>> + +!qalias = !quant.uniform:f32, 2.0:10> +func.func @qcast_per_layer_ranked_bounds(%arg0: tensor<3x5xf32>) -> tensor<3x5x!qalias> { + %0 = quant.qcast %arg0 : tensor<3x5xf32> to tensor<3x5x!qalias> + return %0 : tensor<3x5x!qalias> +} + +// ----- + +// CHECK-LABEL: @qcast_per_layer_unranked +// CHECK-SAME: %[[ARG_0:.*]]: tensor<*xf32> + +// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32> -> tensor +// CHECK: %[[SIZE:.*]] = shape.num_elements %[[SHAPE]] : tensor -> index +// CHECK: %[[SIZE_TENSOR:.*]] = tensor.from_elements %[[SIZE]] : tensor<1xindex> +// CHECK: %[[RANKED_INPUT:.*]] = tensor.reshape %[[ARG_0]](%[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor + +// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8 +// CHECK-DAG: %[[C_0:.*]] = arith.constant 0 : index + +// CHECK: %[[DIM_0:.*]] = tensor.dim %[[RANKED_INPUT]], %[[C_0]] : tensor +// CHECK: %[[SCALE_SPLAT:.*]] = tensor.splat %[[SCALE]]{{\[}}%[[DIM_0]]] : tensor +// CHECK: %[[SCALED:.*]] = arith.divf %[[RANKED_INPUT]], %[[SCALE_SPLAT]] : tensor + +// CHECK: %[[ZERO_POINT_SPLAT:.*]] = tensor.splat %[[ZERO_POINT]]{{\[}}%[[DIM_0]]] : tensor +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_SPLAT]] : tensor to tensor +// CHECK: %[[STORED:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : tensor +// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED]] : tensor to tensor + +// CHECK: %[[STORED_UNRANKED:.*]] = tensor.reshape %[[STORED_INT]](%[[SHAPE]]) : (tensor, tensor) -> tensor<*xi8> +// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_UNRANKED]] : tensor<*xi8> to tensor<*x!quant.uniform> +// CHECK: return %[[STORED_QUANT]] : tensor<*x!quant.uniform> + +!qalias = !quant.uniform +func.func @qcast_per_layer_unranked(%arg0: tensor<*xf32>) -> tensor<*x!qalias> { + %0 = quant.qcast %arg0 : tensor<*xf32> to tensor<*x!qalias> + return %0 : tensor<*x!qalias> +} + +// ----- + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1)> + +// CHECK-LABEL: @qcast_per_channel_ranked +// CHECK-SAME: %[[ARG_0:.*]]: tensor<4x?x?x5xf32> + +// CHECK: %[[SCALES:.*]] = arith.constant dense<[2.000000e+00, 3.000000e+00]> : tensor<2xf32> +// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<[10, 20]> : tensor<2xi8> + +// CHECK-DAG: %[[C_1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[DIM_1:.*]] = tensor.dim %[[ARG_0]], %[[C_1]] : tensor<4x?x?x5xf32> +// CHECK-DAG: %[[C_2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[DIM_2:.*]] = tensor.dim %[[ARG_0]], %[[C_2]] : tensor<4x?x?x5xf32> +// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM_1]], %[[DIM_2]]) : tensor<4x?x?x5xi8> + +// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG_0]], %[[SCALES]], %[[ZERO_POINTS]] : tensor<4x?x?x5xf32>, tensor<2xf32>, tensor<2xi8>) outs(%[[INIT]] : tensor<4x?x?x5xi8>) { +// CHECK: ^bb0(%[[IN:.*]]: f32, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: i8): +// CHECK: %[[SCALED:.*]] = arith.divf %[[IN]], %[[SCALE]] : f32 +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32 +// CHECK: %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : f32 +// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : f32 to i8 +// CHECK: linalg.yield %[[STORED_INT]] : i8 +// CHECK: } -> tensor<4x?x?x5xi8> + +// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[GENERIC]] : tensor<4x?x?x5xi8> to tensor<4x?x?x5x!quant.uniform> +// CHECK: return %[[STORED_QUANT]] : tensor<4x?x?x5x!quant.uniform> + +!qalias = !quant.uniform +func.func @qcast_per_channel_ranked(%arg0: tensor<4x?x?x5xf32>) -> tensor<4x?x?x5x!qalias> { + %0 = quant.qcast %arg0 : tensor<4x?x?x5xf32> to tensor<4x?x?x5x!qalias> + return %0 : tensor<4x?x?x5x!qalias> +} + +// ----- + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1)> + +// CHECK-LABEL: @qcast_per_channel_ranked_bounds +// CHECK-SAME: %[[ARG_0:.*]]: tensor<4x2x5xf32> + +// CHECK: %[[SCALES:.*]] = arith.constant dense<[2.000000e+00, 3.000000e+00]> : tensor<2xf32> +// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<0> : tensor<2xi8> + +// CHECK: %[[INIT:.*]] = tensor.empty() : tensor<4x2x5xi8> +// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[ARG_0]], %[[SCALES]], %[[ZERO_POINTS]] : tensor<4x2x5xf32>, tensor<2xf32>, tensor<2xi8>) outs(%[[INIT]] : tensor<4x2x5xi8>) { +// CHECK: ^bb0(%[[IN:.*]]: f32, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: i8): +// CHECK: %[[SCALED:.*]] = arith.divf %[[IN]], %[[SCALE]] : f32 +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32 +// CHECK: %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : f32 +// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : f32 to i8 +// CHECK: %[[C_NEG_8:.*]] = arith.constant -8 : i8 +// CHECK: %[[C_7:.*]] = arith.constant 7 : i8 +// CHECK: %[[STORED_CLAMPED_TEMP:.*]] = arith.maxsi %[[STORED_INT]], %[[C_NEG_8]] : i8 +// CHECK: %[[STORED_CLAMPED:.*]] = arith.minsi %[[STORED_CLAMPED_TEMP]], %[[C_7]] : i8 +// CHECK: linalg.yield %[[STORED_CLAMPED]] : i8 +// CHECK: } -> tensor<4x2x5xi8> + +// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[GENERIC]] : tensor<4x2x5xi8> to tensor<4x2x5x!quant.uniform:f32:1, {2.000000e+00,3.000000e+00}>> +// CHECK: return %[[STORED_QUANT]] : tensor<4x2x5x!quant.uniform:f32:1, {2.000000e+00,3.000000e+00}>> + +!qalias = !quant.uniform:f32:1, {2.0, 3.0}> +func.func @qcast_per_channel_ranked_bounds(%arg0: tensor<4x2x5xf32>) -> tensor<4x2x5x!qalias> { + %0 = quant.qcast %arg0 : tensor<4x2x5xf32> to tensor<4x2x5x!qalias> + return %0 : tensor<4x2x5x!qalias> +} + +// ----- + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1)> + +// CHECK-LABEL: @qcast_per_channel_unranked +// CHECK-SAME: %[[ARG_0:.*]]: tensor<*xf32> + +// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32> -> tensor +// CHECK: %[[CHANNEL_AXIS:.*]] = arith.constant 2 : index +// CHECK: %[[CHANNEL_AXIS_NEXT:.*]] = arith.constant 3 : index +// CHECK: %[[SHAPE_LEFT:.*]], %[[DISCARDED_0:.*]] = "shape.split_at"(%[[SHAPE]], %[[CHANNEL_AXIS]]) : (tensor, index) -> (tensor, tensor) +// CHECK: %[[SIZE_LEFT:.*]] = shape.num_elements %[[SHAPE_LEFT]] : tensor -> index +// CHECK: %[[DISCARDED_1:.*]], %[[SHAPE_RIGHT:.*]] = "shape.split_at"(%[[SHAPE]], %[[CHANNEL_AXIS_NEXT]]) : (tensor, index) -> (tensor, tensor) +// CHECK: %[[SIZE_RIGHT:.*]] = shape.num_elements %[[SHAPE_RIGHT]] : tensor -> index + +// CHECK: %[[CHANNEL_AXIS_SIZE:.*]] = arith.constant 3 : index +// CHECK: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[SIZE_LEFT]], %[[CHANNEL_AXIS_SIZE]], %[[SIZE_RIGHT]] : tensor<3xindex> +// CHECK: %[[FLAT_INPUT:.*]] = tensor.reshape %[[ARG_0]](%[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor + +// CHECK: %[[SCALES:.*]] = arith.constant dense<[2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<3xf32> +// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<[10, 20, 30]> : tensor<3xi8> + +// CHECK: %[[C_0:.*]] = arith.constant 0 : index +// CHECK: %[[DIM_0:.*]] = tensor.dim %[[FLAT_INPUT]], %[[C_0]] : tensor +// CHECK: %[[C_2:.*]] = arith.constant 2 : index +// CHECK: %[[DIM_2:.*]] = tensor.dim %[[FLAT_INPUT]], %[[C_2]] : tensor +// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM_0]], %[[DIM_2]]) : tensor + +// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[FLAT_INPUT]], %[[SCALES]], %[[ZERO_POINTS]] : tensor, tensor<3xf32>, tensor<3xi8>) outs(%[[INIT]] : tensor) { +// CHECK: ^bb0(%[[IN:.*]]: f32, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: i8): +// CHECK: %[[SCALED:.*]] = arith.divf %[[IN]], %[[SCALE]] : f32 +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32 +// CHECK: %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : f32 +// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : f32 to i8 +// CHECK: linalg.yield %[[STORED_INT]] : i8 +// CHECK: } -> tensor + +// CHECK: %[[STORED_UNRANKED:.*]] = tensor.reshape %[[GENERIC]](%[[SHAPE]]) : (tensor, tensor) -> tensor<*xi8> +// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_UNRANKED]] : tensor<*xi8> to tensor<*x!quant.uniform> +// CHECK: return %[[STORED_QUANT]] : tensor<*x!quant.uniform> + +!qalias = !quant.uniform +func.func @qcast_per_channel_unranked(%arg0: tensor<*xf32>) -> tensor<*x!qalias> { + %0 = quant.qcast %arg0 : tensor<*xf32> to tensor<*x!qalias> + return %0 : tensor<*x!qalias> +} + diff --git a/mlir/test/Dialect/Quant/ops.mlir b/mlir/test/Dialect/Quant/ops.mlir new file mode 100644 index 0000000000000..4abc5830d081e --- /dev/null +++ b/mlir/test/Dialect/Quant/ops.mlir @@ -0,0 +1,151 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics + +!qalias = !quant.uniform +func.func @dcast_scalar(%arg0: !qalias) { + %0 = quant.dcast %arg0 : !qalias to f32 + return +} + +// ----- + +!qalias = !quant.uniform +func.func @dcast_ranked(%arg0: tensor<2x?x4x!qalias>) { + %0 = quant.dcast %arg0 : tensor<2x?x4x!qalias> to tensor<2x?x4xf32> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @dcast_unranked(%arg0: tensor<*x!qalias>) { + %0 = quant.dcast %arg0 : tensor<*x!qalias> to tensor<*xf32> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @dcast_per_axis_static(%arg0: tensor<1x2x3x!qalias>) { + %0 = quant.dcast %arg0 : tensor<1x2x3x!qalias> to tensor<1x2x3xf32> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @dcast_per_axis_dynamic(%arg0: tensor) { + %0 = quant.dcast %arg0 : tensor to tensor + return +} + +// ----- + +!qalias = !quant.uniform +func.func @dcast_per_axis_unranked(%arg0: tensor<*x!qalias>) { + %0 = quant.dcast %arg0 : tensor<*x!qalias> to tensor<*xf32> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_scalar(%arg0: f32) { + %0 = quant.qcast %arg0 : f32 to !qalias + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_ranked(%arg0: tensor<2x?x4xf32>) { + %0 = quant.qcast %arg0 : tensor<2x?x4xf32> to tensor<2x?x4x!qalias> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_unranked(%arg0: tensor<*xf32>) { + %0 = quant.qcast %arg0 : tensor<*xf32> to tensor<*x!qalias> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_per_axis_static(%arg0: tensor<1x2x3xf32>) { + %0 = quant.qcast %arg0 : tensor<1x2x3xf32> to tensor<1x2x3x!qalias> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_per_axis_dynamic(%arg0: tensor) { + %0 = quant.qcast %arg0 : tensor to tensor + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_per_axis_unranked(%arg0: tensor<*xf32>) { + %0 = quant.qcast %arg0 : tensor<*xf32> to tensor<*x!qalias> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_scalar(%arg0: i8) { + %0 = quant.scast %arg0 : i8 to !qalias + %1 = quant.scast %0 : !qalias to i8 + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_ranked(%arg0: tensor<2x?x4xi8>) { + %0 = quant.scast %arg0 : tensor<2x?x4xi8> to tensor<2x?x4x!qalias> + %1 = quant.scast %0 : tensor<2x?x4x!qalias> to tensor<2x?x4xi8> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_unranked(%arg0: tensor<*xi8>) { + %0 = quant.scast %arg0 : tensor<*xi8> to tensor<*x!qalias> + %1 = quant.scast %0 : tensor<*x!qalias> to tensor<*xi8> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_per_axis_static(%arg0: tensor<1x2x3xi8>) { + %0 = quant.scast %arg0 : tensor<1x2x3xi8> to tensor<1x2x3x!qalias> + %1 = quant.scast %0 : tensor<1x2x3x!qalias> to tensor<1x2x3xi8> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_per_axis_dynamic(%arg0: tensor) { + %0 = quant.scast %arg0 : tensor to tensor + %1 = quant.scast %0 : tensor to tensor + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_per_axis_unranked(%arg0: tensor<*xi8>) { + %0 = quant.scast %arg0 : tensor<*xi8> to tensor<*x!qalias> + %1 = quant.scast %0 : tensor<*x!qalias> to tensor<*xi8> + return +} + + diff --git a/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir b/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir index a82e8efdb1a3c..7613a344cf2b8 100644 --- a/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir +++ b/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir @@ -120,3 +120,28 @@ // provided. // expected-error@+1 {{expected floating point literal}} !qalias = !quant.uniform:f32, {2.000000e+02,-19.987200e-01:1}> + +// ----- +// Illegal negative axis in per-axis quantization +// expected-error@+1 {{illegal quantized dimension: -1}} +!qalias = !quant.uniform + +// ----- +// Scale f16 underflow +// expected-error@+1 {{scale out of expressed type range}} +!qalias = !quant.uniform + +// ----- +// Scale f16 overflow +// expected-error@+1 {{scale out of expressed type range}} +!qalias = !quant.uniform + +// ----- +// Scale f16 underflow in per-axis quantization +// expected-error@+1 {{scale out of expressed type range}} +!qalias = !quant.uniform + +// ----- +// Scale f16 overflow in per-axis quantization +// expected-error@+1 {{scale out of expressed type range}} +!qalias = !quant.uniform diff --git a/mlir/test/Dialect/Quant/strip-func-quant-types.mlir b/mlir/test/Dialect/Quant/strip-func-quant-types.mlir new file mode 100644 index 0000000000000..e5f0d4921bed3 --- /dev/null +++ b/mlir/test/Dialect/Quant/strip-func-quant-types.mlir @@ -0,0 +1,88 @@ +// RUN: mlir-opt %s --strip-func-quant-types --split-input-file | FileCheck %s + +// CHECK-LABEL: @strip_operands +// CHECK-SAME: %[[ARG_0:.*]]: i8 +// CHECK-SAME: %[[ARG_1:.*]]: i16 +// CHECK-SAME: %[[ARG_2:.*]]: f32 + +// CHECK: %[[ARG_0_CAST:.*]] = quant.scast %[[ARG_1]] : i16 to !quant.uniform<{{.*}}> +// CHECK: %[[ARG_1_CAST:.*]] = quant.scast %[[ARG_0]] : i8 to !quant.uniform<{{.*}}> + +// CHECK: "test.custom_op"(%[[ARG_1_CAST]]) +// CHECK: "test.custom_op"(%[[ARG_0_CAST]]) +// CHECK: "test.custom_op"(%[[ARG_2]]) + +!qalias = !quant.uniform +!qalias1 = !quant.uniform + +func.func @strip_operands(%arg0: !qalias, %arg1: !qalias1, %arg2: f32) { + "test.custom_op"(%arg0) : (!qalias) -> tensor<4x!qalias> + "test.custom_op"(%arg1) : (!qalias1) -> tensor + "test.custom_op"(%arg2) : (f32) -> tensor<4xf32> +} + +// ----- + +// CHECK-LABEL: @strip_results +// CHECK-SAME: tensor<4xi8>, tensor, tensor<*xi8>, tensor<4xf32> + +// CHECK: %[[RESULT_0:.*]] = "test.custom_op"() +// CHECK: %[[RESULT_CAST_0:.*]] = quant.scast %[[RESULT_0]] : tensor<4x!quant.uniform<{{.*}}>> to tensor<4xi8> + +// CHECK: %[[RESULT_1:.*]] = "test.custom_op"() +// CHECK: %[[RESULT_CAST_1:.*]] = quant.scast %[[RESULT_1]] : tensor> to tensor + +// CHECK: %[[RESULT_2:.*]] = "test.custom_op"() +// CHECK: %[[RESULT_CAST_2:.*]] = quant.scast %[[RESULT_2]] : tensor<*x!quant.uniform<{{.*}}>> to tensor<*xi8> + +// CHECK: %[[RESULT_3:.*]] = "test.custom_op"() + +// CHECK: return %[[RESULT_CAST_0]], %[[RESULT_CAST_1]], %[[RESULT_CAST_2]], %[[RESULT_3]] + +!qalias = !quant.uniform +!qalias1 = !quant.uniform + +func.func @strip_results() -> (tensor<4x!qalias>, tensor, tensor<*x!qalias>, tensor<4xf32>) { + %0 = "test.custom_op"() : () -> tensor<4x!qalias> + %1 = "test.custom_op"() : () -> tensor + %2 = "test.custom_op"() : () -> tensor<*x!qalias> + %3 = "test.custom_op"() : () -> tensor<4xf32> + return %0, %1, %2, %3 : tensor<4x!qalias>, tensor, tensor<*x!qalias>, tensor<4xf32> +} + +// ----- + + +// CHECK-LABEL: @callee +// CHECK-SAME: (tensor<4xi8>, tensor) -> (tensor<*xi8>, tensor<4xf32>) + +// CHECK-LABEL: @strip_call + +// CHECK: %[[OPERAND_0:.*]] = "test.custom_op"() +// CHECK: %[[OPERAND_0_CAST:.*]] = quant.scast %[[OPERAND_0]] : tensor<4x!quant.uniform<{{.*}}>> to tensor<4xi8> + +// CHECK: %[[OPERAND_1:.*]] = "test.custom_op"() +// CHECK: %[[OPERAND_1_CAST:.*]] = quant.scast %[[OPERAND_1]] : tensor> to tensor + +// CHECK: %[[RESULTS:.*]]:2 = call @callee(%[[OPERAND_0_CAST]], %[[OPERAND_1_CAST]]) + +// CHECK: %[[RESULT_0_CAST:.*]] = quant.scast %[[RESULTS]]#0 : tensor<*xi8> to tensor<*x!quant.uniform<{{.*}}>> +// CHECK: "test.custom_op"(%[[RESULT_0_CAST]]) + +// CHECK: "test.custom_op"(%[[RESULTS]]#1) + +// CHECK: return + +!qalias = !quant.uniform +!qalias1 = !quant.uniform + +func.func private @callee(tensor<4x!qalias>, tensor) -> (tensor<*x!qalias>, tensor<4xf32>) + +func.func @strip_call() { + %0 = "test.custom_op"() : () -> tensor<4x!qalias> + %1 = "test.custom_op"() : () -> tensor + %2:2 = func.call @callee(%0, %1) : (tensor<4x!qalias>, tensor) -> (tensor<*x!qalias>, tensor<4xf32>) + "test.custom_op"(%2#0) : (tensor<*x!qalias>) -> () + "test.custom_op"(%2#1) : (tensor<4xf32>) -> () + return +}