From 4de6f81e366e04b51b1ad1a22e911b5412302ec6 Mon Sep 17 00:00:00 2001 From: Rafael Ubal Tena Date: Mon, 8 Jul 2024 17:58:38 -0400 Subject: [PATCH 01/22] File system restructure for 'quant' dialect --- .../include/mlir/Dialect/Quant/CMakeLists.txt | 8 ++---- .../mlir/Dialect/Quant/IR/CMakeLists.txt | 6 +++++ .../Dialect/Quant/{QuantOps.h => IR/Quant.h} | 12 ++++----- .../{QuantOpsBase.td => IR/QuantBase.td} | 8 +++--- .../Quant/{ => IR}/QuantDialectBytecode.td | 0 .../mlir/Dialect/Quant/{ => IR}/QuantOps.td | 8 +++--- .../mlir/Dialect/Quant/{ => IR}/QuantTypes.h | 6 ++--- .../Dialect/Quant/Transforms/CMakeLists.txt | 5 ++++ .../mlir/Dialect/Quant/Transforms/Passes.h | 27 +++++++++++++++++++ .../mlir/Dialect/Quant/Transforms/Passes.td | 26 ++++++++++++++++++ .../Quant/{ => Utils}/FakeQuantSupport.h | 8 +++--- .../Quant/{ => Utils}/UniformSupport.h | 8 +++--- .../mlir/Dialect/Tosa/Utils/QuantUtils.h | 4 +-- mlir/include/mlir/InitAllDialects.h | 2 +- mlir/lib/CAPI/Dialect/Quant.cpp | 4 +-- .../Dialect/Quant/IR/QuantDialectBytecode.cpp | 6 ++--- mlir/lib/Dialect/Quant/IR/QuantOps.cpp | 10 +++---- mlir/lib/Dialect/Quant/IR/QuantTypes.cpp | 4 +-- mlir/lib/Dialect/Quant/IR/TypeParser.cpp | 4 +-- .../Dialect/Quant/Utils/FakeQuantSupport.cpp | 4 +-- .../Dialect/Quant/Utils/UniformSupport.cpp | 2 +- .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 2 +- mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 2 +- 23 files changed, 113 insertions(+), 53 deletions(-) create mode 100644 mlir/include/mlir/Dialect/Quant/IR/CMakeLists.txt rename mlir/include/mlir/Dialect/Quant/{QuantOps.h => IR/Quant.h} (69%) rename mlir/include/mlir/Dialect/Quant/{QuantOpsBase.td => IR/QuantBase.td} (93%) rename mlir/include/mlir/Dialect/Quant/{ => IR}/QuantDialectBytecode.td (100%) rename mlir/include/mlir/Dialect/Quant/{ => IR}/QuantOps.td (96%) rename mlir/include/mlir/Dialect/Quant/{ => IR}/QuantTypes.h (99%) create mode 100644 mlir/include/mlir/Dialect/Quant/Transforms/CMakeLists.txt create mode 100644 mlir/include/mlir/Dialect/Quant/Transforms/Passes.h create mode 100644 mlir/include/mlir/Dialect/Quant/Transforms/Passes.td rename mlir/include/mlir/Dialect/Quant/{ => Utils}/FakeQuantSupport.h (93%) rename mlir/include/mlir/Dialect/Quant/{ => Utils}/UniformSupport.h (97%) diff --git a/mlir/include/mlir/Dialect/Quant/CMakeLists.txt b/mlir/include/mlir/Dialect/Quant/CMakeLists.txt index c08f399ee182d..9f57627c321fb 100644 --- a/mlir/include/mlir/Dialect/Quant/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Quant/CMakeLists.txt @@ -1,6 +1,2 @@ -add_mlir_dialect(QuantOps quant) -add_mlir_doc(QuantOps QuantDialect Dialects/ -gen-dialect-doc) - -set(LLVM_TARGET_DEFINITIONS QuantDialectBytecode.td) -mlir_tablegen(QuantDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Quant") -add_public_tablegen_target(MLIRQuantDialectBytecodeIncGen) +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/include/mlir/Dialect/Quant/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Quant/IR/CMakeLists.txt new file mode 100644 index 0000000000000..c08f399ee182d --- /dev/null +++ b/mlir/include/mlir/Dialect/Quant/IR/CMakeLists.txt @@ -0,0 +1,6 @@ +add_mlir_dialect(QuantOps quant) +add_mlir_doc(QuantOps QuantDialect Dialects/ -gen-dialect-doc) + +set(LLVM_TARGET_DEFINITIONS QuantDialectBytecode.td) +mlir_tablegen(QuantDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Quant") +add_public_tablegen_target(MLIRQuantDialectBytecodeIncGen) diff --git a/mlir/include/mlir/Dialect/Quant/QuantOps.h b/mlir/include/mlir/Dialect/Quant/IR/Quant.h 
similarity index 69% rename from mlir/include/mlir/Dialect/Quant/QuantOps.h rename to mlir/include/mlir/Dialect/Quant/IR/Quant.h index 14fb3035ab0d3..a703612d6b489 100644 --- a/mlir/include/mlir/Dialect/Quant/QuantOps.h +++ b/mlir/include/mlir/Dialect/Quant/IR/Quant.h @@ -1,4 +1,4 @@ -//===- QuantOps.h - Quantization Ops and Types ------------------*- C++ -*-===// +//===- Quant.h - Quantization Ops -------------------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_QUANT_QUANTOPS_H_ -#define MLIR_DIALECT_QUANT_QUANTOPS_H_ +#ifndef MLIR_DIALECT_QUANT_IR_QUANT_H_ +#define MLIR_DIALECT_QUANT_IR_QUANT_H_ #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" @@ -19,9 +19,9 @@ #include "mlir/Interfaces/SideEffectInterfaces.h" #include "llvm/Support/MathExtras.h" -#include "mlir/Dialect/Quant/QuantOpsDialect.h.inc" +#include "mlir/Dialect/Quant/IR/QuantOpsDialect.h.inc" #define GET_OP_CLASSES -#include "mlir/Dialect/Quant/QuantOps.h.inc" +#include "mlir/Dialect/Quant/IR/QuantOps.h.inc" -#endif // MLIR_DIALECT_QUANT_QUANTOPS_H_ +#endif // MLIR_DIALECT_QUANT_IR_QUANT_H_ diff --git a/mlir/include/mlir/Dialect/Quant/QuantOpsBase.td b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td similarity index 93% rename from mlir/include/mlir/Dialect/Quant/QuantOpsBase.td rename to mlir/include/mlir/Dialect/Quant/IR/QuantBase.td index da822d0a61deb..dadca06091b1e 100644 --- a/mlir/include/mlir/Dialect/Quant/QuantOpsBase.td +++ b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td @@ -1,4 +1,4 @@ -//===- QuantOpsBase.td - Quantization dialect base ---------*- tablegen -*-===// +//===- QuantBase.td - Quantization dialect base ------------*- tablegen -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -10,8 +10,8 @@ // //===----------------------------------------------------------------------===// -#ifndef DIALECT_QUANT_QUANT_OPS_BASE_ -#define DIALECT_QUANT_QUANT_OPS_BASE_ +#ifndef QUANT_BASE +#define QUANT_BASE include "mlir/IR/OpBase.td" @@ -71,4 +71,4 @@ def quant_UniformQuantizedType : def quant_UniformQuantizedValueType : quant_TypedPrimitiveOrContainer; -#endif // DIALECT_QUANT_QUANT_OPS_BASE_ +#endif // QUANT_BASE diff --git a/mlir/include/mlir/Dialect/Quant/QuantDialectBytecode.td b/mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td similarity index 100% rename from mlir/include/mlir/Dialect/Quant/QuantDialectBytecode.td rename to mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td diff --git a/mlir/include/mlir/Dialect/Quant/QuantOps.td b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td similarity index 96% rename from mlir/include/mlir/Dialect/Quant/QuantOps.td rename to mlir/include/mlir/Dialect/Quant/IR/QuantOps.td index 7937265ce2f20..a3a0ff1608a66 100644 --- a/mlir/include/mlir/Dialect/Quant/QuantOps.td +++ b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td @@ -10,10 +10,10 @@ // //===----------------------------------------------------------------------===// -#ifndef DIALECT_QUANT_QUANT_OPS_ -#define DIALECT_QUANT_QUANT_OPS_ +#ifndef QUANT_OPS +#define QUANT_OPS -include "mlir/Dialect/Quant/QuantOpsBase.td" +include "mlir/Dialect/Quant/IR/QuantBase.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -100,4 +100,4 @@ def quant_StorageCastOp : quant_Op<"scast", [Pure]> { let hasFolder = 1; } -#endif // DIALECT_QUANT_QUANT_OPS_ +#endif // QUANT_OPS diff --git a/mlir/include/mlir/Dialect/Quant/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h similarity index 99% rename from mlir/include/mlir/Dialect/Quant/QuantTypes.h rename to mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h index de5aed0a91a20..c020e1b46ad4e 100644 --- a/mlir/include/mlir/Dialect/Quant/QuantTypes.h +++ b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_QUANT_QUANTTYPES_H -#define MLIR_DIALECT_QUANT_QUANTTYPES_H +#ifndef MLIR_DIALECT_QUANT_IR_QUANTTYPES_H +#define MLIR_DIALECT_QUANT_IR_QUANTTYPES_H #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" @@ -412,4 +412,4 @@ class CalibratedQuantizedType } // namespace quant } // namespace mlir -#endif // MLIR_DIALECT_QUANT_QUANTTYPES_H +#endif // MLIR_DIALECT_QUANT_IR_QUANTTYPES_H diff --git a/mlir/include/mlir/Dialect/Quant/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Quant/Transforms/CMakeLists.txt new file mode 100644 index 0000000000000..30f7c1696bdb9 --- /dev/null +++ b/mlir/include/mlir/Dialect/Quant/Transforms/CMakeLists.txt @@ -0,0 +1,5 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name Quant) +add_public_tablegen_target(MLIRQuantTransformsIncGen) + +add_mlir_doc(Passes QuantPasses ./ -gen-pass-doc) diff --git a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.h b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.h new file mode 100644 index 0000000000000..0b7378651afa1 --- /dev/null +++ b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.h @@ -0,0 +1,27 @@ +//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_QUANT_TRANSFORMS_PASSES_H_ +#define MLIR_DIALECT_QUANT_TRANSFORMS_PASSES_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace quant { + +#define GEN_PASS_DECL +#include "mlir/Dialect/Quant/Transforms/Passes.h.inc" + +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/Quant/Transforms/Passes.h.inc" + +} // namespace quant +} // namespace mlir + +#endif // MLIR_DIALECT_QUANT_TRANSFORMS_PASSES_H_ diff --git a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td new file mode 100644 index 0000000000000..f511c90ec6931 --- /dev/null +++ b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td @@ -0,0 +1,26 @@ +//===-- Passes.td - Arith pass definition file --------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_QUANT_TRANSFORMS_PASSES +#define MLIR_DIALECT_QUANT_TRANSFORMS_PASSES + +include "mlir/Pass/PassBase.td" + +def QuantLowerQuantOpsPass : Pass<"lower-quant-ops"> { + let summary = "Lower 'quant.dcast' and 'quant.qcast' ops."; + let description = [{ + Lower quantization (`quant.qcast`) and dequantization (`quant.dcast`) ops + into other core dialects. + + The lowering process generates storage type casts in the form of + `quant.scast` ops to convert operands and results from quantized types to + the corresponding storage type, or vice versa. 
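To make the pass description above concrete, here is a sketch of the kind of IR rewrite it performs for a per-layer uniform quantized type. The type parameters (scale 2.0, zero point 10), the SSA names, and the use of the custom assembly format introduced later in this series are all illustrative, not output captured from the pass:

```
// Before: quantize a float tensor.
%q = quant.qcast %fp : tensor<4xf32> to tensor<4x!quant.uniform<i8:f32, 2.0:10>>

// After (conceptually): divide by the scale, add the zero point, convert to
// the i8 storage type (with a clamp when the type narrows the storage range),
// and re-enter the quantized type through a storage cast.
%scale = arith.constant dense<2.0> : tensor<4xf32>
%zp    = arith.constant dense<10.0> : tensor<4xf32>
%div   = arith.divf %fp, %scale : tensor<4xf32>
%add   = arith.addf %div, %zp : tensor<4xf32>
%int   = arith.fptosi %add : tensor<4xf32> to tensor<4xi8>
%q     = quant.scast %int : tensor<4xi8> to tensor<4x!quant.uniform<i8:f32, 2.0:10>>
```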
+ let dependentDialects = ["quant::QuantDialect"]; +} + +#endif // MLIR_DIALECT_QUANT_TRANSFORMS_PASSES diff --git a/mlir/include/mlir/Dialect/Quant/FakeQuantSupport.h b/mlir/include/mlir/Dialect/Quant/Utils/FakeQuantSupport.h similarity index 93% rename from mlir/include/mlir/Dialect/Quant/FakeQuantSupport.h rename to mlir/include/mlir/Dialect/Quant/Utils/FakeQuantSupport.h index 367d468b2acf1..6551efc6242a6 100644 --- a/mlir/include/mlir/Dialect/Quant/FakeQuantSupport.h +++ b/mlir/include/mlir/Dialect/Quant/Utils/FakeQuantSupport.h @@ -34,10 +34,10 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_QUANT_FAKEQUANTSUPPORT_H_ -#define MLIR_DIALECT_QUANT_FAKEQUANTSUPPORT_H_ +#ifndef MLIR_DIALECT_QUANT_UTILS_FAKEQUANTSUPPORT_H_ +#define MLIR_DIALECT_QUANT_UTILS_FAKEQUANTSUPPORT_H_ -#include "mlir/Dialect/Quant/QuantTypes.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" namespace mlir { namespace quant { @@ -64,4 +64,4 @@ fakeQuantAttrsToType(Location loc, unsigned numBits, int32_t quantizedDimension, } // namespace quant } // namespace mlir -#endif // MLIR_DIALECT_QUANT_FAKEQUANTSUPPORT_H_ +#endif // MLIR_DIALECT_QUANT_UTILS_FAKEQUANTSUPPORT_H_ diff --git a/mlir/include/mlir/Dialect/Quant/UniformSupport.h b/mlir/include/mlir/Dialect/Quant/Utils/UniformSupport.h similarity index 97% rename from mlir/include/mlir/Dialect/Quant/UniformSupport.h rename to mlir/include/mlir/Dialect/Quant/Utils/UniformSupport.h index 4119aced4c075..6773f45069c87 100644 --- a/mlir/include/mlir/Dialect/Quant/UniformSupport.h +++ b/mlir/include/mlir/Dialect/Quant/Utils/UniformSupport.h @@ -6,12 +6,12 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_QUANT_UNIFORMSUPPORT_H_ -#define MLIR_DIALECT_QUANT_UNIFORMSUPPORT_H_ +#ifndef MLIR_DIALECT_QUANT_UTILS_UNIFORMSUPPORT_H_ +#define MLIR_DIALECT_QUANT_UTILS_UNIFORMSUPPORT_H_ #include -#include "mlir/Dialect/Quant/QuantTypes.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Types.h" #include "llvm/ADT/APFloat.h" @@ -218,4 +218,4 @@ class UniformQuantizedPerAxisValueConverter { } // namespace quant } // namespace mlir -#endif // MLIR_DIALECT_QUANT_UNIFORMSUPPORT_H_ +#endif // MLIR_DIALECT_QUANT_UTILS_UNIFORMSUPPORT_H_ diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h index 298c97015fe2e..5e80745777b3b 100644 --- a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h +++ b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h @@ -16,8 +16,8 @@ #include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/Dialect/Quant/FakeQuantSupport.h" -#include "mlir/Dialect/Quant/UniformSupport.h" +#include "mlir/Dialect/Quant/Utils/FakeQuantSupport.h" +#include "mlir/Dialect/Quant/Utils/UniformSupport.h" namespace mlir { namespace tosa { diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h index 549c26c72d8a1..75e62cda90d45 100644 --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -64,7 +64,7 @@ #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" #include "mlir/Dialect/Polynomial/IR/PolynomialDialect.h" #include "mlir/Dialect/Ptr/IR/PtrDialect.h" -#include "mlir/Dialect/Quant/QuantOps.h" +#include "mlir/Dialect/Quant/IR/Quant.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h" #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" diff --git 
a/mlir/lib/CAPI/Dialect/Quant.cpp b/mlir/lib/CAPI/Dialect/Quant.cpp index 0a7181d8bc17c..b30d1de73288c 100644 --- a/mlir/lib/CAPI/Dialect/Quant.cpp +++ b/mlir/lib/CAPI/Dialect/Quant.cpp @@ -8,8 +8,8 @@ #include "mlir-c/Dialect/Quant.h" #include "mlir/CAPI/Registration.h" -#include "mlir/Dialect/Quant/QuantOps.h" -#include "mlir/Dialect/Quant/QuantTypes.h" +#include "mlir/Dialect/Quant/IR/Quant.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" using namespace mlir; diff --git a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp index c0c00fb4893cb..0f4b755367495 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp @@ -9,8 +9,8 @@ #include "QuantDialectBytecode.h" #include "mlir/Bytecode/BytecodeImplementation.h" -#include "mlir/Dialect/Quant/QuantOps.h" -#include "mlir/Dialect/Quant/QuantTypes.h" +#include "mlir/Dialect/Quant/IR/Quant.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" #include "mlir/IR/Diagnostics.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/SmallVector.h" @@ -31,7 +31,7 @@ static LogicalResult readDoubleAPFloat(DialectBytecodeReader &reader, return success(); } -#include "mlir/Dialect/Quant/QuantDialectBytecode.cpp.inc" +#include "mlir/Dialect/Quant/IR/QuantDialectBytecode.cpp.inc" /// This class implements the bytecode interface for the Quant dialect. struct QuantDialectBytecodeInterface : public BytecodeDialectInterface { diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp index c9a6bbc9ceeea..fa9725c23d643 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp @@ -6,11 +6,11 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Quant/QuantOps.h" #include "QuantDialectBytecode.h" #include "TypeDetail.h" -#include "mlir/Dialect/Quant/QuantTypes.h" +#include "mlir/Dialect/Quant/IR/Quant.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Matchers.h" @@ -24,14 +24,14 @@ using namespace mlir; using namespace mlir::quant; using namespace mlir::quant::detail; -#include "mlir/Dialect/Quant/QuantOpsDialect.cpp.inc" +#include "mlir/Dialect/Quant/IR/QuantOpsDialect.cpp.inc" void QuantizationDialect::initialize() { addTypes(); addOperations< #define GET_OP_LIST -#include "mlir/Dialect/Quant/QuantOps.cpp.inc" +#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc" >(); addBytecodeInterface(this); } @@ -46,4 +46,4 @@ OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) { } #define GET_OP_CLASSES -#include "mlir/Dialect/Quant/QuantOps.cpp.inc" +#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc" diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp index 81e3b914755be..15cde77e40afb 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp @@ -6,9 +6,9 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Quant/QuantTypes.h" #include "TypeDetail.h" -#include "mlir/Dialect/Quant/QuantOps.h" +#include "mlir/Dialect/Quant/IR/Quant.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp index 926a8a0aa13d5..c882a616f397c 100644 --- 
a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp +++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Quant/QuantOps.h" -#include "mlir/Dialect/Quant/QuantTypes.h" +#include "mlir/Dialect/Quant/IR/Quant.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Location.h" diff --git a/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp b/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp index 8c69729824691..fb27640bfd278 100644 --- a/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp +++ b/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Quant/FakeQuantSupport.h" -#include "mlir/Dialect/Quant/QuantTypes.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" +#include "mlir/Dialect/Quant/Utils/FakeQuantSupport.h" using namespace mlir; using namespace mlir::quant; diff --git a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp index 408701f80444a..62c7a7128d63a 100644 --- a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp +++ b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Quant/UniformSupport.h" +#include "mlir/Dialect/Quant/Utils/UniformSupport.h" #include "mlir/IR/BuiltinTypes.h" #include diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 8687be075ea67..5d082ea9b1010 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -11,7 +11,7 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Quant/QuantOps.h" +#include "mlir/Dialect/Quant/IR/Quant.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 4337787e4aead..d3aff36a763b6 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -14,7 +14,7 @@ #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" -#include "mlir/Dialect/Quant/QuantOps.h" +#include "mlir/Dialect/Quant/IR/Quant.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" From d6e9adcc3d8859f9dbedbf9360d59cc91a99b9ed Mon Sep 17 00:00:00 2001 From: Rafael Ubal Tena Date: Mon, 8 Jul 2024 19:26:06 -0400 Subject: [PATCH 02/22] Successfully compiled empty pass '-lower-quant-ops' --- .../mlir/Dialect/Quant/IR/QuantBase.td | 4 +-- .../include/mlir/Dialect/Quant/IR/QuantOps.td | 2 +- .../mlir/Dialect/Quant/Transforms/Passes.td | 5 +-- .../mlir/Dialect/Tosa/IR/TosaOpBase.td | 2 +- mlir/include/mlir/InitAllDialects.h | 2 +- mlir/include/mlir/InitAllPasses.h | 2 ++ mlir/lib/CAPI/Dialect/Quant.cpp | 2 +- mlir/lib/Dialect/Quant/CMakeLists.txt | 1 + .../Dialect/Quant/IR/QuantDialectBytecode.cpp | 2 +- .../Dialect/Quant/IR/QuantDialectBytecode.h | 4 +-- mlir/lib/Dialect/Quant/IR/QuantOps.cpp | 2 +- mlir/lib/Dialect/Quant/IR/QuantTypes.cpp | 2 +- 
mlir/lib/Dialect/Quant/IR/TypeParser.cpp | 4 +-- .../Dialect/Quant/Transforms/CMakeLists.txt | 15 ++++++++ .../Quant/Transforms/LowerQuantOps.cpp | 34 +++++++++++++++++++ 15 files changed, 68 insertions(+), 15 deletions(-) create mode 100644 mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt create mode 100644 mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td index dadca06091b1e..e465d855c1986 100644 --- a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td +++ b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td @@ -15,7 +15,7 @@ include "mlir/IR/OpBase.td" -def Quantization_Dialect : Dialect { +def Quant_Dialect : Dialect { let name = "quant"; let cppNamespace = "::mlir::quant"; @@ -63,7 +63,7 @@ def quant_RealOrStorageValueType : // An implementation of UniformQuantizedType. def quant_UniformQuantizedType : - DialectType($_self)">, "UniformQuantizedType">; diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td index a3a0ff1608a66..ba282d50328da 100644 --- a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td +++ b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td @@ -22,7 +22,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td" //===----------------------------------------------------------------------===// class quant_Op traits> : - Op; + Op; //===----------------------------------------------------------------------===// // Quantization casts diff --git a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td index f511c90ec6931..e43062a98b1ea 100644 --- a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td @@ -11,7 +11,7 @@ include "mlir/Pass/PassBase.td" -def QuantLowerQuantOpsPass : Pass<"lower-quant-ops"> { +def LowerQuantOps : Pass<"lower-quant-ops"> { let summary = "Lower 'quant.dcast' and 'quant.qcast' ops."; let description = [{ Lower quantization (`quant.qcast`) and dequantization (`quant.dcast`) ops @@ -20,7 +20,8 @@ def QuantLowerQuantOpsPass : Pass<"lower-quant-ops"> { The lowering process generates storage type casts in the form of `quant.scast` ops to convert operands and results from quantized types to the corresponding storage type, or vice versa. - let dependentDialects = ["quant::QuantDialect"]; + }]; + let dependentDialects = ["::mlir::quant::QuantDialect"]; } #endif // MLIR_DIALECT_QUANT_TRANSFORMS_PASSES diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td index 1412c7a2615d2..df91ba51a0594 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td @@ -40,7 +40,7 @@ def Tosa_Dialect : Dialect { there will be tools to lower from the ML frameworks into TOSA. 
}]; - let dependentDialects = ["tensor::TensorDialect", "quant::QuantizationDialect"]; + let dependentDialects = ["tensor::TensorDialect", "quant::QuantDialect"]; let cppNamespace = "mlir::tosa"; let hasConstantMaterializer = 1; diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h index 75e62cda90d45..08de36fe21db0 100644 --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -136,7 +136,7 @@ inline void registerAllDialects(DialectRegistry ®istry) { pdl_interp::PDLInterpDialect, polynomial::PolynomialDialect, ptr::PtrDialect, - quant::QuantizationDialect, + quant::QuantDialect, ROCDL::ROCDLDialect, scf::SCFDialect, shape::ShapeDialect, diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h index 1b9c1b193ace6..dd8b292a87344 100644 --- a/mlir/include/mlir/InitAllPasses.h +++ b/mlir/include/mlir/InitAllPasses.h @@ -35,6 +35,7 @@ #include "mlir/Dialect/Mesh/Transforms/Passes.h" #include "mlir/Dialect/NVGPU/Transforms/Passes.h" #include "mlir/Dialect/OpenACC/Transforms/Passes.h" +#include "mlir/Dialect/Quant/Transforms/Passes.h" #include "mlir/Dialect/SCF/Transforms/Passes.h" #include "mlir/Dialect/SPIRV/Transforms/Passes.h" #include "mlir/Dialect/Shape/Transforms/Passes.h" @@ -82,6 +83,7 @@ inline void registerAllPasses() { memref::registerMemRefPasses(); mesh::registerMeshPasses(); ml_program::registerMLProgramPasses(); + quant::registerQuantPasses(); registerSCFPasses(); registerShapePasses(); spirv::registerSPIRVPasses(); diff --git a/mlir/lib/CAPI/Dialect/Quant.cpp b/mlir/lib/CAPI/Dialect/Quant.cpp index b30d1de73288c..c94dbb5692fdb 100644 --- a/mlir/lib/CAPI/Dialect/Quant.cpp +++ b/mlir/lib/CAPI/Dialect/Quant.cpp @@ -13,7 +13,7 @@ using namespace mlir; -MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(quant, quant, quant::QuantizationDialect) +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(quant, quant, quant::QuantDialect) //===---------------------------------------------------------------------===// // QuantizedType diff --git a/mlir/lib/Dialect/Quant/CMakeLists.txt b/mlir/lib/Dialect/Quant/CMakeLists.txt index 037bba8dcb5c9..31167e6af908b 100644 --- a/mlir/lib/Dialect/Quant/CMakeLists.txt +++ b/mlir/lib/Dialect/Quant/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(IR) +add_subdirectory(Transforms) add_subdirectory(Utils) diff --git a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp index 0f4b755367495..6a4ac310eb052 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp @@ -64,6 +64,6 @@ struct QuantDialectBytecodeInterface : public BytecodeDialectInterface { }; } // namespace -void quant::detail::addBytecodeInterface(QuantizationDialect *dialect) { +void quant::detail::addBytecodeInterface(QuantDialect *dialect) { dialect->addInterfaces(); } diff --git a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.h b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.h index 9e9cbf66d84d9..eef2b5bbefecc 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.h +++ b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.h @@ -15,12 +15,12 @@ #define LIB_MLIR_DIALECT_QUANT_IR_QUANTDIALECTBYTECODE_H namespace mlir::quant { -class QuantizationDialect; +class QuantDialect; namespace detail { /// Add the interfaces necessary for encoding the quantization dialect /// components in bytecode. 
-void addBytecodeInterface(QuantizationDialect *dialect); +void addBytecodeInterface(QuantDialect *dialect); } // namespace detail } // namespace mlir::quant diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp index fa9725c23d643..49c05aa7f98d3 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp @@ -26,7 +26,7 @@ using namespace mlir::quant::detail; #include "mlir/Dialect/Quant/IR/QuantOpsDialect.cpp.inc" -void QuantizationDialect::initialize() { +void QuantDialect::initialize() { addTypes(); addOperations< diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp index 15cde77e40afb..a4829d472ecad 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp @@ -25,7 +25,7 @@ unsigned QuantizedType::getFlags() const { } bool QuantizedType::classof(Type type) { - return llvm::isa(type.getDialect()); + return llvm::isa(type.getDialect()); } LogicalResult diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp index c882a616f397c..bf0f775146c1a 100644 --- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp +++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp @@ -317,7 +317,7 @@ static Type parseCalibratedType(DialectAsmParser &parser) { } /// Parse a type registered to this dialect. -Type QuantizationDialect::parseType(DialectAsmParser &parser) const { +Type QuantDialect::parseType(DialectAsmParser &parser) const { // All types start with an identifier that we switch on. StringRef typeNameSpelling; if (failed(parser.parseKeyword(&typeNameSpelling))) @@ -419,7 +419,7 @@ static void printCalibratedQuantizedType(CalibratedQuantizedType type, } /// Print a type registered to this dialect. -void QuantizationDialect::printType(Type type, DialectAsmPrinter &os) const { +void QuantDialect::printType(Type type, DialectAsmPrinter &os) const { if (auto anyType = llvm::dyn_cast(type)) printAnyQuantizedType(anyType, os); else if (auto uniformType = llvm::dyn_cast(type)) diff --git a/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt new file mode 100644 index 0000000000000..2daea7750cfe3 --- /dev/null +++ b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt @@ -0,0 +1,15 @@ +add_mlir_dialect_library(MLIRQuantTransforms + LowerQuantOps.cpp + + ADDITIONAL_HEADER_DIRS + {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Quant/Transforms + + DEPENDS + MLIRQuantTransformsIncGen + + LINK_LIBS PUBLIC + MLIRQuantDialect + MLIRPass + MLIRTransforms + MLIRTransformUtils + ) diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp new file mode 100644 index 0000000000000..72f89326f555b --- /dev/null +++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp @@ -0,0 +1,34 @@ +//===- LowerQuantOps.cpp - Lower 'quant' dialect ops ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Transforms `quant.dcast` and `quant.qcast` into lower-level ops. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Quant/IR/Quant.h" +#include "mlir/Dialect/Quant/Transforms/Passes.h" +#include "mlir/IR/PatternMatch.h" + +namespace mlir { +namespace quant { + +#define GEN_PASS_DEF_LOWERQUANTOPS +#include "mlir/Dialect/Quant/Transforms/Passes.h.inc" + +namespace { + +struct LowerQuantOps : public impl::LowerQuantOpsBase { + void runOnOperation() override { + Operation *parentOp = getOperation(); + } +}; + +} // namespace + +} // namespace quant +} // namespace mlir From 7e0a7b9edb810dc80eba435cbdabf2dd1e53d9b0 Mon Sep 17 00:00:00 2001 From: Rafael Ubal Tena Date: Wed, 10 Jul 2024 19:51:20 -0400 Subject: [PATCH 03/22] Lowering for 'quant.qcast' with per-layer quantization for ranked and unranked tensors --- mlir/include/mlir/Dialect/Quant/IR/Quant.h | 8 + .../include/mlir/Dialect/Quant/IR/QuantOps.td | 46 ++-- .../mlir/Dialect/Quant/IR/QuantTypes.h | 4 + .../mlir/Dialect/Quant/Transforms/Passes.h | 2 + .../mlir/Dialect/Quant/Transforms/Passes.td | 17 +- mlir/lib/Dialect/Quant/IR/QuantOps.cpp | 53 +++- mlir/lib/Dialect/Quant/IR/QuantTypes.cpp | 11 + mlir/lib/Dialect/Quant/IR/TypeParser.cpp | 7 +- .../Quant/Transforms/LowerQuantOps.cpp | 253 +++++++++++++++++- 9 files changed, 363 insertions(+), 38 deletions(-) diff --git a/mlir/include/mlir/Dialect/Quant/IR/Quant.h b/mlir/include/mlir/Dialect/Quant/IR/Quant.h index a703612d6b489..c5ca88ec69795 100644 --- a/mlir/include/mlir/Dialect/Quant/IR/Quant.h +++ b/mlir/include/mlir/Dialect/Quant/IR/Quant.h @@ -21,6 +21,14 @@ #include "mlir/Dialect/Quant/IR/QuantOpsDialect.h.inc" +namespace mlir { +namespace quant { + +class UniformQuantizedType; + +} // namespace quant +} // namespace mlir + #define GET_OP_CLASSES #include "mlir/Dialect/Quant/IR/QuantOps.h.inc" diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td index ba282d50328da..48e2496203ff0 100644 --- a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td +++ b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td @@ -28,6 +28,24 @@ class quant_Op traits> : // Quantization casts //===----------------------------------------------------------------------===// +def quant_DequantizeCastOp : quant_Op<"dcast", [Pure]> { + let summary = "convert back from a quantized to quantizable (expressed) type operation"; + let description = [{ + A DequantizeCast op `dcast` represents the inverse of a `qcast`, + converting back from a quantized to quantizable (expressed) type. + + Like `qcast`s, a `dcast` is allowed to have both its operand and result + as non quantized types. This facilitates transformations and marks edges + where the computation must be carried out in the expressed type. + + Especially early in transformation, it is common to have `dcast`s on + all operands to ops that must operate with the expressed type (typically + math ops prior to lowering to target-specific, quantized kernels). + }]; + let arguments = (ins quant_RealValueType:$input); + let results = (outs quant_RealValueType:$result); +} + def quant_QuantizeCastOp : quant_Op<"qcast", [Pure]> { let summary = "convert a quantizable type to a quantized type"; let description = [{ @@ -52,26 +70,18 @@ def quant_QuantizeCastOp : quant_Op<"qcast", [Pure]> { it is legal to use a quantized representation (but is not known to be acceptable). 
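A worked example of the value-level semantics described above, with assumed parameters (scale 0.5, zero point 8) and the custom assembly format that a later patch in this series introduces:

```
// qcast stores:   3.0 / 0.5 + 8 = 14   (an i8 storage value)
// dcast recovers: (14 - 8) * 0.5 = 3.0
%q = quant.qcast %x : f32 to !quant.uniform<i8:f32, 0.5:8>
%y = quant.dcast %q : !quant.uniform<i8:f32, 0.5:8> to f32
```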
}]; - let arguments = (ins quant_RealValueType:$arg); - let results = (outs quant_RealValueType:$res); -} + let arguments = (ins quant_RealValueType:$input); + let results = (outs quant_RealValueType:$result); -def quant_DequantizeCastOp : quant_Op<"dcast", [Pure]> { - let summary = "convert back from a quantized to quantizable (expressed) type operation"; - let description = [{ - A DequantizeCast op `dcast` represents the inverse of a `qcast`, - converting back from a quantized to quantizable (expressed) type. + let extraClassDeclaration = [{ - Like `qcast`s, a `dcast` is allowed to have both its operand and result - as non quantized types. This facilitates transformations and marks edges - where the computation must be carried out in the expressed type. + /// Return the primitive (scalar or tensor element) type of the float input. + FloatType getFloatType(); - Especially early in transformation, it is common to have `dcast`s on - all operands to ops that must operate with the expressed type (typically - math ops prior to lowering to target-specific, quantized kernels). + /// Return the primitive (scalar or tensor element) type of the quantized + /// result. + quant::UniformQuantizedType getQuantizedType(); }]; - let arguments = (ins quant_RealValueType:$arg); - let results = (outs quant_RealValueType:$res); } def quant_StorageCastOp : quant_Op<"scast", [Pure]> { @@ -95,8 +105,8 @@ def quant_StorageCastOp : quant_Op<"scast", [Pure]> { vector<4xi8> -> vector<4x!quant<"uniform[i8:f32]{1.0}">> ``` }]; - let arguments = (ins quant_RealOrStorageValueType:$arg); - let results = (outs quant_RealOrStorageValueType:$res); + let arguments = (ins quant_RealOrStorageValueType:$input); + let results = (outs quant_RealOrStorageValueType:$result); let hasFolder = 1; } diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h index c020e1b46ad4e..d4c9e7f9286a6 100644 --- a/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h +++ b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h @@ -114,6 +114,10 @@ class QuantizedType : public Type { /// The maximum value that storageType can take. int64_t getStorageTypeMax() const; + /// Return whether the storage type has explicit min or max boundaries + /// different from the minimum and maximum representable values. + bool hasStorageTypeBounds() const; + /// Gets the integral bit width that the underlying storage type can exactly /// represent. For integral storage types, this will just be their width. 
unsigned getStorageTypeIntegralWidth() const; diff --git a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.h b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.h index 0b7378651afa1..84be2a21b34ed 100644 --- a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.h @@ -21,6 +21,8 @@ namespace quant { #define GEN_PASS_REGISTRATION #include "mlir/Dialect/Quant/Transforms/Passes.h.inc" +void populateLowerQuantOpsPatterns(RewritePatternSet &patterns); + } // namespace quant } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td index e43062a98b1ea..19aa37f653c01 100644 --- a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td @@ -11,17 +11,24 @@ include "mlir/Pass/PassBase.td" -def LowerQuantOps : Pass<"lower-quant-ops"> { - let summary = "Lower 'quant.dcast' and 'quant.qcast' ops."; +def LowerQuantOps : Pass<"lower-quant-ops", "func::FuncOp"> { + let summary = "Lower quant.dcast and quant.qcast ops"; let description = [{ Lower quantization (`quant.qcast`) and dequantization (`quant.dcast`) ops into other core dialects. The lowering process generates storage type casts in the form of - `quant.scast` ops to convert operands and results from quantized types to - the corresponding storage type, or vice versa. + `quant.scast` ops to act as an interface between the original quantized + types of operands and results and their corresponding storage types used in + the generated arithmetic computations. }]; - let dependentDialects = ["::mlir::quant::QuantDialect"]; + let dependentDialects = [ + "arith::ArithDialect", + "linalg::LinalgDialect", + "quant::QuantDialect", + "scf::SCFDialect", + "tensor::TensorDialect" + ]; } #endif // MLIR_DIALECT_QUANT_TRANSFORMS_PASSES diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp index 49c05aa7f98d3..2eafa348f6906 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp @@ -20,12 +20,27 @@ #include "llvm/Support/MathExtras.h" #include -using namespace mlir; -using namespace mlir::quant; -using namespace mlir::quant::detail; - #include "mlir/Dialect/Quant/IR/QuantOpsDialect.cpp.inc" + +namespace mlir { +namespace quant { + +namespace { + +Type getPrimitiveType(Type ty) { + if (auto tensorType = dyn_cast(ty)) + return tensorType.getElementType(); + return ty; +} + +} // namespace + + +//===----------------------------------------------------------------------===// +// Dialect +//===----------------------------------------------------------------------===// + void QuantDialect::initialize() { addTypes(); @@ -33,17 +48,39 @@ void QuantDialect::initialize() { #define GET_OP_LIST #include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc" >(); - addBytecodeInterface(this); + detail::addBytecodeInterface(this); } + +//===----------------------------------------------------------------------===// +// QuantizeCastOp +//===----------------------------------------------------------------------===// + +FloatType QuantizeCastOp::getFloatType() { + return cast(getPrimitiveType(getInput().getType())); +} + +UniformQuantizedType QuantizeCastOp::getQuantizedType() { + return cast(getPrimitiveType(getResult().getType())); +} + + +//===----------------------------------------------------------------------===// +// StorageCastOp +//===----------------------------------------------------------------------===// + 
OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) { // Matches x -> [scast -> scast] -> y, replacing the second scast with the // value of x if the casts invert each other. - auto srcScastOp = getArg().getDefiningOp(); - if (!srcScastOp || srcScastOp.getArg().getType() != getType()) + auto srcScastOp = getInput().getDefiningOp(); + if (!srcScastOp || srcScastOp.getInput().getType() != getType()) return OpFoldResult(); - return srcScastOp.getArg(); + return srcScastOp.getInput(); } +} // namespace quant +} // namespace mlir + #define GET_OP_CLASSES #include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc" + diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp index a4829d472ecad..2038a86bec8d6 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp @@ -72,6 +72,17 @@ int64_t QuantizedType::getStorageTypeMax() const { return static_cast(impl)->storageTypeMax; } +bool QuantizedType::hasStorageTypeBounds() const { + unsigned int integralWidth = getStorageTypeIntegralWidth(); + bool isSignedInteger = isSigned(); + int64_t defaultIntegerMin = + getDefaultMinimumForInteger(isSignedInteger, integralWidth); + int64_t defaultIntegerMax = + getDefaultMaximumForInteger(isSignedInteger, integralWidth); + return defaultIntegerMin != getStorageTypeMin() || + defaultIntegerMax != getStorageTypeMax(); +} + unsigned QuantizedType::getStorageTypeIntegralWidth() const { // NOTE: If ever supporting non-integral storage types, some other scheme // for determining the width will be needed. diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp index bf0f775146c1a..851763d8942e8 100644 --- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp +++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp @@ -346,12 +346,7 @@ static void printStorageType(QuantizedType type, DialectAsmPrinter &out) { } // storageTypeMin and storageTypeMax if not default. 
- int64_t defaultIntegerMin = - QuantizedType::getDefaultMinimumForInteger(isSigned, storageWidth); - int64_t defaultIntegerMax = - QuantizedType::getDefaultMaximumForInteger(isSigned, storageWidth); - if (defaultIntegerMin != type.getStorageTypeMin() || - defaultIntegerMax != type.getStorageTypeMax()) { + if (type.hasStorageTypeBounds()) { out << "<" << type.getStorageTypeMin() << ":" << type.getStorageTypeMax() << ">"; } diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp index 72f89326f555b..1067dc4f950d3 100644 --- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp +++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp @@ -10,9 +10,16 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Quant/IR/Quant.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" #include "mlir/Dialect/Quant/Transforms/Passes.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" namespace mlir { namespace quant { @@ -22,13 +29,257 @@ namespace quant { namespace { +//===----------------------------------------------------------------------===// +// DequantizeCastOp +//===----------------------------------------------------------------------===// + +class DequantizeCastOpConversion : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(quant::DequantizeCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + return success(); + } +}; + + +//===----------------------------------------------------------------------===// +// QuantizeCastOp +//===----------------------------------------------------------------------===// + +// If 'containerType' is a tensor, return its element type. If it is a scalar, +// return it as is. +Type getScalarType(Type containerType) { + if (auto tensorType = dyn_cast(containerType)) + return tensorType.getElementType(); + return containerType; +} + +// Return the shape of a container as a combination of attributes (static +// dimensions) and values (dynamic dimensions). If 'container' is a scalar, +// an empty list is returned. If 'container' is a tensor, its shape is returned. +SmallVector getContainerShape(OpBuilder &builder, Location loc, + Value container) { + if (isa(container.getType())) + return tensor::getMixedSizes(builder, loc, container); + return {}; +} + +// Clone the given 'containerType' with the new given 'elementType'. If +// 'containerType' is a scalar type, there is nothing to clone, and +// 'elementType' itself is returned. If 'constainerType' is a tensor, its +// shape is cloned but the new element type is used. +Type cloneContainerType(Type containerType, Type elementType) { + if (auto tensorType = dyn_cast(containerType)) + return tensorType.clone(elementType); + return elementType; +} + +// Get a scalar or tensor constant containing the value given in 'attr'. +// If 'containerType' is a scalar, a scalar constant is returned. If +// 'containerType' is a tensor, a tensor splat of shape 'containerShape' is +// returned. 
+Value getContainerConstant(OpBuilder &builder, Location loc, TypedAttr attr, + Type containerType, + ArrayRef containerShape) { + // A statically shaped tensor can be created with 'arith.constant' + auto tensorType = dyn_cast(containerType); + if (tensorType && tensorType.hasStaticShape()) { + auto denseElementsAttr = DenseElementsAttr::get(tensorType, attr); + return builder.create(loc, tensorType, denseElementsAttr); + } + + // Scalar and dynamically shaped tensor containers need the scalar constant + // to be first materialized. + Value containerConstant = + builder.create(loc, attr.getType(), attr); + + // Create tensor splat if necessary + if (tensorType) { + containerConstant = + builder.create(loc, containerConstant, containerShape); + } + return containerConstant; +} + +// Calculate the size of an unranked tensor starting at dimension 'fromDim' up +// to, but not including, dimension 'toDim'. +Value getUnrankedTensorSizeRange(OpBuilder &builder, Location loc, Value input, + Value fromDim, Value toDim, Value one) { + auto loop = builder.create( + loc, + fromDim, // lowerBound + toDim, // upperBound + one, // step + one, // iterArgs + [&](OpBuilder &builder, Location loc, Value index, ValueRange args) { + Value size = builder.create(loc, input, index); + Value totalSize = builder.create(loc, args.front(), size); + builder.create(loc, totalSize); + }); + return loop.getResult(0); +} + +// Obtain the shape of an unranked tensor. This function returns a 1D tensor of +// size 'rank' and element type 'index'. +Value getUnrankedTensorShape(OpBuilder &builder, Location loc, Value input, + Value rank) { + auto shapeType = + RankedTensorType::get({ShapedType::kDynamic}, builder.getIndexType()); + auto shape = builder.create( + loc, + shapeType, + rank, + [&](OpBuilder &builder, Location loc, ValueRange args) { + Value size = builder.create(loc, input, args.front()); + builder.create(loc, size); + }); + return shape; +} + +class QuantizeCastOpConversion : public OpConversionPattern { + + Value convertPerLayerScalarOrRanked( + OpBuilder &builder, Location loc, Value input, + UniformQuantizedType quantizedType) const { + + auto inputType = input.getType(); + auto expressedType = cast(quantizedType.getExpressedType()); + auto storageType = cast(quantizedType.getStorageType()); + auto storageContainerType = cloneContainerType(inputType, storageType); + + auto inputShape = getContainerShape(builder, loc, input); + + // Scale and zero point scalars + auto scaleAttr = builder.getFloatAttr(expressedType, quantizedType.getScale()); + auto scale = getContainerConstant(builder, loc, scaleAttr, inputType, inputShape); + auto zeroPointAttr = builder.getFloatAttr(expressedType, quantizedType.getZeroPoint()); + auto zeroPoint = getContainerConstant(builder, loc, zeroPointAttr, inputType, inputShape); + + auto scaledValue = builder.create(loc, input, scale); + auto storedValueAsExpressedType = builder.create(loc, scaledValue, zeroPoint); + + Value storedValue; + if (quantizedType.isSigned()) { + storedValue = builder.create( + loc, storageContainerType, storedValueAsExpressedType); + } else { + storedValue = builder.create( + loc, storageContainerType, storedValueAsExpressedType); + } + + // Clamp stored value if needed + if (quantizedType.hasStorageTypeBounds()) { + auto storageMinAttr = builder.getIntegerAttr(storageType, quantizedType.getStorageTypeMin()); + auto storageMaxAttr = builder.getIntegerAttr(storageType, quantizedType.getStorageTypeMax()); + auto storageMin = getContainerConstant(builder, loc, 
storageMinAttr, inputType, inputShape); + auto storageMax = getContainerConstant(builder, loc, storageMaxAttr, inputType, inputShape); + if (quantizedType.isSigned()) { + storedValue = builder.create(loc, storedValue, storageMin); + storedValue = builder.create(loc, storedValue, storageMax); + } else { + storedValue = builder.create(loc, storedValue, storageMin); + storedValue = builder.create(loc, storedValue, storageMax); + } + } + + return storedValue; + } + + Value convertPerLayerUnranked( + OpBuilder &builder, Location loc, Value input, + UniformQuantizedType quantizedType) const { + auto rank = builder.create(loc, input); + auto inputShape = getUnrankedTensorShape(builder, loc, input, rank); + auto inputType = cast(input.getType()); + + auto zero = builder.create(loc, 0); + auto one = builder.create(loc, 1); + auto inputSize = getUnrankedTensorSizeRange(builder, loc, input, zero, rank, one); + + // Compute collapsed input shape as a 1D 1-sized index tensor + auto collapsedInputShapeType = RankedTensorType::get({1}, builder.getIndexType()); + auto collapsedInputShape = builder.create( + loc, collapsedInputShapeType, inputSize); + + // Reshape input tensor into 1D + auto collapsedInputType = RankedTensorType::get({ShapedType::kDynamic}, + inputType.getElementType()); + auto collapsedInput = builder.create( + loc, collapsedInputType, input, collapsedInputShape); + + // Now we know how to convert a ranked tensor + auto collapsedStoredValue = convertPerLayerScalarOrRanked( + builder, loc, collapsedInput, quantizedType); + + // Expand stored value back to the original shape + auto expandedStoredValueType = + UnrankedTensorType::get(quantizedType.getStorageType()); + auto expandedStoredValue = builder.create( + loc, expandedStoredValueType, collapsedStoredValue, inputShape); + return expandedStoredValue; + } + +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(quant::QuantizeCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto input = op.getInput(); + auto resultScalarType = getScalarType(op.getResult().getType()); + + // Per-layer vs per-channel quantization + Value storedValue; + if (auto quantizedType = dyn_cast(resultScalarType)) { + storedValue = isa(input.getType()) ? 
+ convertPerLayerUnranked(rewriter, loc, input, quantizedType) : + convertPerLayerScalarOrRanked(rewriter, loc, input, quantizedType); + } else if (auto quantizedType = dyn_cast(resultScalarType)) { + // FIXM + } else { + llvm_unreachable("unexpected quantized type"); + } + + // Cast stored value to result quantized value + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), storedValue); + return success(); + } +}; + struct LowerQuantOps : public impl::LowerQuantOpsBase { void runOnOperation() override { - Operation *parentOp = getOperation(); + RewritePatternSet patterns(&getContext()); + populateLowerQuantOpsPatterns(patterns); + + ConversionTarget target(getContext()); + target.addLegalOp(); + target.addIllegalDialect(); + target.addLegalDialect< + arith::ArithDialect, + linalg::LinalgDialect, + scf::SCFDialect, + tensor::TensorDialect + >(); + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + signalPassFailure(); } }; } // namespace +void populateLowerQuantOpsPatterns(RewritePatternSet &patterns) { + patterns.add< + DequantizeCastOpConversion, + QuantizeCastOpConversion + >(patterns.getContext()); +} + } // namespace quant } // namespace mlir From 8f8f6de514a07452f60442464fdd8a77a166a9f4 Mon Sep 17 00:00:00 2001 From: Rafael Ubal Tena Date: Thu, 11 Jul 2024 18:01:31 -0400 Subject: [PATCH 04/22] Custom formats for 'quant' ops and use of 'shape' dialect ops for 'lower-quant-ops' --- .../include/mlir/Dialect/Quant/IR/QuantOps.td | 13 +---- .../mlir/Dialect/Quant/Transforms/Passes.td | 2 +- mlir/lib/Dialect/Quant/IR/QuantOps.cpp | 13 ----- .../Quant/Transforms/LowerQuantOps.cpp | 57 ++++--------------- 4 files changed, 15 insertions(+), 70 deletions(-) diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td index 48e2496203ff0..7a6d270dbb6e9 100644 --- a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td +++ b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td @@ -44,6 +44,7 @@ def quant_DequantizeCastOp : quant_Op<"dcast", [Pure]> { }]; let arguments = (ins quant_RealValueType:$input); let results = (outs quant_RealValueType:$result); + let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)"; } def quant_QuantizeCastOp : quant_Op<"qcast", [Pure]> { @@ -72,16 +73,7 @@ def quant_QuantizeCastOp : quant_Op<"qcast", [Pure]> { }]; let arguments = (ins quant_RealValueType:$input); let results = (outs quant_RealValueType:$result); - - let extraClassDeclaration = [{ - - /// Return the primitive (scalar or tensor element) type of the float input. - FloatType getFloatType(); - - /// Return the primitive (scalar or tensor element) type of the quantized - /// result. 
- quant::UniformQuantizedType getQuantizedType(); - }]; + let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)"; } def quant_StorageCastOp : quant_Op<"scast", [Pure]> { @@ -107,6 +99,7 @@ def quant_StorageCastOp : quant_Op<"scast", [Pure]> { }]; let arguments = (ins quant_RealOrStorageValueType:$input); let results = (outs quant_RealOrStorageValueType:$result); + let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)"; let hasFolder = 1; } diff --git a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td index 19aa37f653c01..56e10688b0c98 100644 --- a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td @@ -26,7 +26,7 @@ def LowerQuantOps : Pass<"lower-quant-ops", "func::FuncOp"> { "arith::ArithDialect", "linalg::LinalgDialect", "quant::QuantDialect", - "scf::SCFDialect", + "shape::ShapeDialect", "tensor::TensorDialect" ]; } diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp index 2eafa348f6906..e04ca7eb7e715 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp @@ -52,19 +52,6 @@ void QuantDialect::initialize() { } -//===----------------------------------------------------------------------===// -// QuantizeCastOp -//===----------------------------------------------------------------------===// - -FloatType QuantizeCastOp::getFloatType() { - return cast(getPrimitiveType(getInput().getType())); -} - -UniformQuantizedType QuantizeCastOp::getQuantizedType() { - return cast(getPrimitiveType(getResult().getType())); -} - - //===----------------------------------------------------------------------===// // StorageCastOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp index 1067dc4f950d3..1daafdd715155 100644 --- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp +++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp @@ -16,7 +16,7 @@ #include "mlir/Dialect/Quant/IR/Quant.h" #include "mlir/Dialect/Quant/IR/QuantTypes.h" #include "mlir/Dialect/Quant/Transforms/Passes.h" -#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" @@ -104,41 +104,6 @@ Value getContainerConstant(OpBuilder &builder, Location loc, TypedAttr attr, return containerConstant; } -// Calculate the size of an unranked tensor starting at dimension 'fromDim' up -// to, but not including, dimension 'toDim'. -Value getUnrankedTensorSizeRange(OpBuilder &builder, Location loc, Value input, - Value fromDim, Value toDim, Value one) { - auto loop = builder.create( - loc, - fromDim, // lowerBound - toDim, // upperBound - one, // step - one, // iterArgs - [&](OpBuilder &builder, Location loc, Value index, ValueRange args) { - Value size = builder.create(loc, input, index); - Value totalSize = builder.create(loc, args.front(), size); - builder.create(loc, totalSize); - }); - return loop.getResult(0); -} - -// Obtain the shape of an unranked tensor. This function returns a 1D tensor of -// size 'rank' and element type 'index'. 
-Value getUnrankedTensorShape(OpBuilder &builder, Location loc, Value input, - Value rank) { - auto shapeType = - RankedTensorType::get({ShapedType::kDynamic}, builder.getIndexType()); - auto shape = builder.create( - loc, - shapeType, - rank, - [&](OpBuilder &builder, Location loc, ValueRange args) { - Value size = builder.create(loc, input, args.front()); - builder.create(loc, size); - }); - return shape; -} - class QuantizeCastOpConversion : public OpConversionPattern { Value convertPerLayerScalarOrRanked( @@ -191,18 +156,18 @@ class QuantizeCastOpConversion : public OpConversionPattern(loc, input); - auto inputShape = getUnrankedTensorShape(builder, loc, input, rank); + auto *context = builder.getContext(); auto inputType = cast(input.getType()); - auto zero = builder.create(loc, 0); - auto one = builder.create(loc, 1); - auto inputSize = getUnrankedTensorSizeRange(builder, loc, input, zero, rank, one); + auto shapeType = shape::getExtentTensorType(context); + auto inputShape = builder.create(loc, shapeType, input); + Value inputSize = builder.create( + loc, builder.getIndexType(), inputShape); - // Compute collapsed input shape as a 1D 1-sized index tensor - auto collapsedInputShapeType = RankedTensorType::get({1}, builder.getIndexType()); + // Turn input size into 1D tensor + auto collapsedShapeType = shape::getExtentTensorType(context, 1); auto collapsedInputShape = builder.create( - loc, collapsedInputShapeType, inputSize); + loc, collapsedShapeType, inputSize); // Reshape input tensor into 1D auto collapsedInputType = RankedTensorType::get({ShapedType::kDynamic}, @@ -210,7 +175,7 @@ class QuantizeCastOpConversion : public OpConversionPattern( loc, collapsedInputType, input, collapsedInputShape); - // Now we know how to convert a ranked tensor + // We now know how to deal with a 1D ranked input auto collapsedStoredValue = convertPerLayerScalarOrRanked( builder, loc, collapsedInput, quantizedType); @@ -262,7 +227,7 @@ struct LowerQuantOps : public impl::LowerQuantOpsBase { target.addLegalDialect< arith::ArithDialect, linalg::LinalgDialect, - scf::SCFDialect, + shape::ShapeDialect, tensor::TensorDialect >(); From ba88a9c38a212637c3446aca204d7d8f767fad33 Mon Sep 17 00:00:00 2001 From: Rafael Ubal Tena Date: Mon, 15 Jul 2024 18:08:50 -0400 Subject: [PATCH 05/22] Progress in per-channel quantization --- .../Quant/Transforms/LowerQuantOps.cpp | 237 ++++++++++++------ 1 file changed, 167 insertions(+), 70 deletions(-) diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp index 1daafdd715155..f5f4de807c8f0 100644 --- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp +++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp @@ -104,87 +104,186 @@ Value getContainerConstant(OpBuilder &builder, Location loc, TypedAttr attr, return containerConstant; } +std::pair flattenUnrankedTensor(OpBuilder &builder, Location loc, + Value input) { + // Get unranked input shape and total size + auto *context = builder.getContext(); + auto shapeType = shape::getExtentTensorType(context); + auto inputShape = builder.create(loc, shapeType, input); + Value inputSize = builder.create( + loc, builder.getIndexType(), inputShape); + + // Turn input size into 1D tensor + auto flatShapeType = shape::getExtentTensorType(context, 1); + auto flatInputShape = builder.create( + loc, flatShapeType, inputSize); + + // Reshape input tensor into 1D + auto inputType = cast(input.getType()); + auto flatInputType = + 
RankedTensorType::get({ShapedType::kDynamic}, inputType.getElementType()); + auto flatInput = builder.create( + loc, flatInputType, input, flatInputShape); + return std::make_pair(flatInput, inputShape); +} + +Value restoreUnrankedTensor(OpBuilder &builder, Location loc, Value input, + Value shape) { + auto inputType = cast(input.getType()); + auto elementType = inputType.getElementType(); + auto unrankedType = UnrankedTensorType::get(elementType); + return builder.create(loc, unrankedType, input, shape); +} + +Value materializeScales(OpBuilder &builder, Location loc, + UniformQuantizedPerAxisType quantizedType) { + auto scales = quantizedType.getScales(); + auto expressedType = quantizedType.getExpressedType(); + auto scaleAttrs = llvm::map_to_vector(scales, [&](double scale) -> Attribute { + return builder.getFloatAttr(expressedType, scale); + }); + auto tensorType = RankedTensorType::get({(int64_t) scales.size()}, expressedType); + auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs); + return builder.create(loc, tensorType, scalesAttr); +} + +Value materializeZeroPoints(OpBuilder &builder, Location loc, + UniformQuantizedPerAxisType quantizedType) { + auto zeroPoints = quantizedType.getZeroPoints(); + auto expressedType = quantizedType.getExpressedType(); + auto zeroPointAttrs = llvm::map_to_vector(zeroPoints, [&](int64_t zeroPoint) -> Attribute { + return builder.getFloatAttr(expressedType, static_cast(zeroPoint)); + }); + auto tensorType = RankedTensorType::get({(int64_t) zeroPoints.size()}, expressedType); + auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs); + return builder.create(loc, tensorType, zeroPointsAttr); +} + +Value quantizeValue(OpBuilder &builder, Location loc, Value input, + ArrayRef inputShape, Value scale, + Value zeroPoint, QuantizedType quantizedType) { + auto inputType = input.getType(); + auto storageType = cast(quantizedType.getStorageType()); + auto storageContainerType = cloneContainerType(inputType, storageType); + + auto scaledValue = builder.create(loc, input, scale); + auto storedValueAsExpressedType = builder.create(loc, scaledValue, zeroPoint); + + Value storedValue; + if (quantizedType.isSigned()) { + storedValue = builder.create( + loc, storageContainerType, storedValueAsExpressedType); + } else { + storedValue = builder.create( + loc, storageContainerType, storedValueAsExpressedType); + } + + // Clamp stored value if needed + if (quantizedType.hasStorageTypeBounds()) { + auto storageMinAttr = builder.getIntegerAttr(storageType, quantizedType.getStorageTypeMin()); + auto storageMaxAttr = builder.getIntegerAttr(storageType, quantizedType.getStorageTypeMax()); + auto storageMin = getContainerConstant(builder, loc, storageMinAttr, inputType, inputShape); + auto storageMax = getContainerConstant(builder, loc, storageMaxAttr, inputType, inputShape); + if (quantizedType.isSigned()) { + storedValue = builder.create(loc, storedValue, storageMin); + storedValue = builder.create(loc, storedValue, storageMax); + } else { + storedValue = builder.create(loc, storedValue, storageMin); + storedValue = builder.create(loc, storedValue, storageMax); + } + } + + return storedValue; +} + class QuantizeCastOpConversion : public OpConversionPattern { - Value convertPerLayerScalarOrRanked( - OpBuilder &builder, Location loc, Value input, - UniformQuantizedType quantizedType) const { + Value convertPerLayerRanked(OpBuilder &builder, Location loc, Value input, + UniformQuantizedType quantizedType) const { auto inputType = input.getType(); 
auto expressedType = cast(quantizedType.getExpressedType()); - auto storageType = cast(quantizedType.getStorageType()); - auto storageContainerType = cloneContainerType(inputType, storageType); + // Create scale and zero point constants auto inputShape = getContainerShape(builder, loc, input); - - // Scale and zero point scalars auto scaleAttr = builder.getFloatAttr(expressedType, quantizedType.getScale()); auto scale = getContainerConstant(builder, loc, scaleAttr, inputType, inputShape); auto zeroPointAttr = builder.getFloatAttr(expressedType, quantizedType.getZeroPoint()); auto zeroPoint = getContainerConstant(builder, loc, zeroPointAttr, inputType, inputShape); - auto scaledValue = builder.create(loc, input, scale); - auto storedValueAsExpressedType = builder.create(loc, scaledValue, zeroPoint); + return quantizeValue(builder, loc, input, inputShape, scale, zeroPoint, + quantizedType); + } - Value storedValue; - if (quantizedType.isSigned()) { - storedValue = builder.create( - loc, storageContainerType, storedValueAsExpressedType); - } else { - storedValue = builder.create( - loc, storageContainerType, storedValueAsExpressedType); - } + Value convertPerLayer(OpBuilder &builder, Location loc, Value input, + UniformQuantizedType quantizedType) const { + // Flatten input if unranked + bool isUnranked = isa(input.getType()); + Value shape; + if (isUnranked) + std::tie(input, shape) = flattenUnrankedTensor(builder, loc, input); - // Clamp stored value if needed - if (quantizedType.hasStorageTypeBounds()) { - auto storageMinAttr = builder.getIntegerAttr(storageType, quantizedType.getStorageTypeMin()); - auto storageMaxAttr = builder.getIntegerAttr(storageType, quantizedType.getStorageTypeMax()); - auto storageMin = getContainerConstant(builder, loc, storageMinAttr, inputType, inputShape); - auto storageMax = getContainerConstant(builder, loc, storageMaxAttr, inputType, inputShape); - if (quantizedType.isSigned()) { - storedValue = builder.create(loc, storedValue, storageMin); - storedValue = builder.create(loc, storedValue, storageMax); - } else { - storedValue = builder.create(loc, storedValue, storageMin); - storedValue = builder.create(loc, storedValue, storageMax); - } - } + // Process ranked tensor + auto result = convertPerLayerRanked(builder, loc, input, quantizedType); - return storedValue; + // Restore original shape if unranked + if (isUnranked) + result = restoreUnrankedTensor(builder, loc, result, shape); + + return result; } - - Value convertPerLayerUnranked( - OpBuilder &builder, Location loc, Value input, - UniformQuantizedType quantizedType) const { + + Value convertPerChannelRanked(OpBuilder &builder, Location loc, Value input, + UniformQuantizedPerAxisType quantizedType, + int32_t channelAxis) const { auto *context = builder.getContext(); - auto inputType = cast(input.getType()); - - auto shapeType = shape::getExtentTensorType(context); - auto inputShape = builder.create(loc, shapeType, input); - Value inputSize = builder.create( - loc, builder.getIndexType(), inputShape); - - // Turn input size into 1D tensor - auto collapsedShapeType = shape::getExtentTensorType(context, 1); - auto collapsedInputShape = builder.create( - loc, collapsedShapeType, inputSize); - - // Reshape input tensor into 1D - auto collapsedInputType = RankedTensorType::get({ShapedType::kDynamic}, - inputType.getElementType()); - auto collapsedInput = builder.create( - loc, collapsedInputType, input, collapsedInputShape); - - // We now know how to deal with a 1D ranked input - auto collapsedStoredValue = 
convertPerLayerScalarOrRanked( - builder, loc, collapsedInput, quantizedType); - - // Expand stored value back to the original shape - auto expandedStoredValueType = - UnrankedTensorType::get(quantizedType.getStorageType()); - auto expandedStoredValue = builder.create( - loc, expandedStoredValueType, collapsedStoredValue, inputShape); - return expandedStoredValue; + + auto inputType = cast(input.getType()); + auto inputRank = inputType.getRank(); + + auto scales = materializeScales(builder, loc, quantizedType); + auto zeroPoints = materializeZeroPoints(builder, loc, quantizedType); + + auto storageType = quantizedType.getStorageType(); + auto initShape = tensor::getMixedSizes(builder, loc, input); + Value init = builder.create(loc, initShape, storageType); + + SmallVector iteratorTypes( + inputRank, utils::IteratorType::parallel); + auto channelAxisAffineMap = AffineMap::get( + inputRank, 0, builder.getAffineDimExpr(channelAxis), context); + SmallVector indexingMaps{ + builder.getMultiDimIdentityMap(inputRank), + channelAxisAffineMap, + channelAxisAffineMap, + builder.getMultiDimIdentityMap(inputRank) + }; + auto storedValue = builder.create( + loc, + init.getType(), // resultType + ValueRange{input, scales, zeroPoints}, // inputs + ValueRange{init}, // outputs + indexingMaps, + iteratorTypes, + [&](OpBuilder& builder, Location loc, ValueRange args) { + assert(args.size() == 4); + auto expressedValue = args[0]; + auto scale = args[1]; + auto zeroPoint = args[2]; + + auto storedValue = quantizeValue(builder, loc, expressedValue, {}, + scale, zeroPoint, quantizedType); + + builder.create(loc, storedValue); + }) + .getResult(0); + + return storedValue; + } + + Value convertPerChannel(OpBuilder &builder, Location loc, Value input, + UniformQuantizedPerAxisType quantizedType) const { + return convertPerChannelRanked(builder, loc, input, quantizedType, quantizedType.getQuantizedDimension()); } public: @@ -197,18 +296,16 @@ class QuantizeCastOpConversion : public OpConversionPattern(resultScalarType)) { - storedValue = isa(input.getType()) ? 
- convertPerLayerUnranked(rewriter, loc, input, quantizedType) : - convertPerLayerScalarOrRanked(rewriter, loc, input, quantizedType); + storedValue = convertPerLayer(rewriter, loc, input, quantizedType); } else if (auto quantizedType = dyn_cast(resultScalarType)) { - // FIXM + storedValue = convertPerChannel(rewriter, loc, input, quantizedType); } else { llvm_unreachable("unexpected quantized type"); } - + // Cast stored value to result quantized value rewriter.replaceOpWithNewOp( op, op.getResult().getType(), storedValue); From 6b9caccad131a0d2e6cf22419b0486cf01f1b28b Mon Sep 17 00:00:00 2001 From: Rafael Ubal Tena Date: Tue, 16 Jul 2024 10:15:24 -0400 Subject: [PATCH 06/22] Support for unranked tensor in per-channel quantization --- .../Quant/Transforms/LowerQuantOps.cpp | 72 +++++++++++++++++-- 1 file changed, 66 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp index f5f4de807c8f0..49ef8ea354070 100644 --- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp +++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp @@ -120,10 +120,52 @@ std::pair flattenUnrankedTensor(OpBuilder &builder, Location loc, // Reshape input tensor into 1D auto inputType = cast(input.getType()); + auto elementType = inputType.getElementType(); auto flatInputType = - RankedTensorType::get({ShapedType::kDynamic}, inputType.getElementType()); + RankedTensorType::get({ShapedType::kDynamic}, elementType); + auto flatInput = builder.create( + loc, flatInputType, input, flatInputShape); + return std::make_pair(flatInput, inputShape); +} + +std::pair flattenUnrankedTensorAroundAxis(OpBuilder &builder, + Location loc, + Value input, + int64_t axis) { + // Get full tensor shape + auto *context = builder.getContext(); + auto indexType = builder.getIndexType(); + auto shapeType = shape::getExtentTensorType(context); + auto inputShape = builder.create(loc, shapeType, input); + + // Get shape and sizes on left and right of axis + auto axisValue = builder.create(loc, axis); + auto axisNextValue = builder.create(loc, axis + 1); + auto shapeLeft = builder.create( + loc, TypeRange{shapeType, shapeType}, inputShape, axisValue) + .getResult(0); + auto sizeLeft = builder.create( + loc, indexType, shapeLeft); + auto shapeRight = builder.create( + loc, TypeRange{shapeType, shapeType}, inputShape, axisNextValue) + .getResult(1); + auto sizeRight = builder.create( + loc, indexType, shapeRight); + Value axisSize = builder.create(loc, input, axisValue); + + // Compute flat input shape as a 3-element 1D tensor + auto flatShapeType = shape::getExtentTensorType(context, 3); + auto flatInputShape = builder.create( + loc, flatShapeType, ValueRange{sizeLeft, axisSize, sizeRight}); + + // Reshape input to 3D tensor + auto inputType = cast(input.getType()); + auto elementType = inputType.getElementType(); + SmallVector flatInputDims(3, ShapedType::kDynamic); + auto flatInputType = RankedTensorType::get(flatInputDims, elementType); auto flatInput = builder.create( loc, flatInputType, input, flatInputShape); + return std::make_pair(flatInput, inputShape); } @@ -219,23 +261,23 @@ class QuantizeCastOpConversion : public OpConversionPattern(input.getType()); - Value shape; + Value inputShape; if (isUnranked) - std::tie(input, shape) = flattenUnrankedTensor(builder, loc, input); + std::tie(input, inputShape) = flattenUnrankedTensor(builder, loc, input); // Process ranked tensor auto result = convertPerLayerRanked(builder, loc, input, quantizedType); 
// Restore original shape if unranked if (isUnranked) - result = restoreUnrankedTensor(builder, loc, result, shape); + result = restoreUnrankedTensor(builder, loc, result, inputShape); return result; } Value convertPerChannelRanked(OpBuilder &builder, Location loc, Value input, UniformQuantizedPerAxisType quantizedType, - int32_t channelAxis) const { + int64_t channelAxis) const { auto *context = builder.getContext(); auto inputType = cast(input.getType()); @@ -283,7 +325,25 @@ class QuantizeCastOpConversion : public OpConversionPattern(input.getType()); + int64_t channelAxis = quantizedType.getQuantizedDimension(); + Value inputShape; + if (isUnranked) { + std::tie(input, inputShape) = + flattenUnrankedTensorAroundAxis(builder, loc, input, channelAxis); + channelAxis = 1; + } + + // Work on a ranked tensor + auto result = convertPerChannelRanked(builder, loc, input, quantizedType, + channelAxis); + + // Restore original tensor shape if unranked + if (isUnranked) + result = restoreUnrankedTensor(builder, loc, result, inputShape); + + return result; } public: From b2fee689c74012651662d4edc8c00988aea5344b Mon Sep 17 00:00:00 2001 From: Rafael Ubal Tena Date: Tue, 16 Jul 2024 15:24:15 -0400 Subject: [PATCH 07/22] Refactored 'quant.qcast' lowering. Ready to begin 'quant.dcast' --- .../Quant/Transforms/LowerQuantOps.cpp | 270 ++++++++++-------- 1 file changed, 158 insertions(+), 112 deletions(-) diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp index 49ef8ea354070..68a4328128292 100644 --- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp +++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp @@ -49,59 +49,50 @@ class DequantizeCastOpConversion : public OpConversionPattern(containerType)) +Type getScalarType(Type inputType) { + if (auto tensorType = dyn_cast(inputType)) return tensorType.getElementType(); - return containerType; + return inputType; } -// Return the shape of a container as a combination of attributes (static -// dimensions) and values (dynamic dimensions). If 'container' is a scalar, -// an empty list is returned. If 'container' is a tensor, its shape is returned. -SmallVector getContainerShape(OpBuilder &builder, Location loc, - Value container) { - if (isa(container.getType())) - return tensor::getMixedSizes(builder, loc, container); +// Return the shape of an input value as a list of attributes (static dimensions) +// and values (dynamic dimensions). If 'input' is a scalar, an empty list is +// returned. If 'input' is a tensor, its shape is returned. +SmallVector +getScalarOrTensorShape(OpBuilder &builder, Location loc, Value input) { + if (isa(input.getType())) + return tensor::getMixedSizes(builder, loc, input); return {}; } -// Clone the given 'containerType' with the new given 'elementType'. If -// 'containerType' is a scalar type, there is nothing to clone, and -// 'elementType' itself is returned. If 'constainerType' is a tensor, its -// shape is cloned but the new element type is used. -Type cloneContainerType(Type containerType, Type elementType) { - if (auto tensorType = dyn_cast(containerType)) +// If 'referenceType' is a scalar, return 'elementType' as is. If +// 'referenceType' is a tensor, return another tensor with the same shape and +// elements of type 'elementType'. 
+Type getScalarOrTensorType(Type elementType, Type referenceType) { + if (auto tensorType = dyn_cast(referenceType)) return tensorType.clone(elementType); return elementType; } -// Get a scalar or tensor constant containing the value given in 'attr'. -// If 'containerType' is a scalar, a scalar constant is returned. If -// 'containerType' is a tensor, a tensor splat of shape 'containerShape' is -// returned. -Value getContainerConstant(OpBuilder &builder, Location loc, TypedAttr attr, - Type containerType, - ArrayRef containerShape) { - // A statically shaped tensor can be created with 'arith.constant' - auto tensorType = dyn_cast(containerType); - if (tensorType && tensorType.hasStaticShape()) { - auto denseElementsAttr = DenseElementsAttr::get(tensorType, attr); - return builder.create(loc, tensorType, denseElementsAttr); +// Return a constant with the given value. If 'referenceType' is a tensor, a +// tensor splat of shape 'referenceShape' is returned. If 'referenceType' is a +// scalar, 'referenceShape' is ignored and a scalar constant is returned. +Value getScalarOrTensorConstant(OpBuilder &builder, Location loc, Value scalar, + Type referenceType, + ArrayRef referenceShape) { + // If the result type is a scalar, return the unmodified scalar constant. + auto tensorType = dyn_cast(referenceType); + if (!tensorType) { + assert(referenceShape.empty()); + return scalar; } - // Scalar and dynamically shaped tensor containers need the scalar constant - // to be first materialized. - Value containerConstant = - builder.create(loc, attr.getType(), attr); - - // Create tensor splat if necessary - if (tensorType) { - containerConstant = - builder.create(loc, containerConstant, containerShape); - } - return containerConstant; + // Create tensor splat + auto tensorConstant = + builder.create(loc, scalar, referenceShape); + return tensorConstant; } std::pair flattenUnrankedTensor(OpBuilder &builder, Location loc, @@ -131,7 +122,8 @@ std::pair flattenUnrankedTensor(OpBuilder &builder, Location loc, std::pair flattenUnrankedTensorAroundAxis(OpBuilder &builder, Location loc, Value input, - int64_t axis) { + int64_t axis, + int64_t axisSize) { // Get full tensor shape auto *context = builder.getContext(); auto indexType = builder.getIndexType(); @@ -151,34 +143,34 @@ std::pair flattenUnrankedTensorAroundAxis(OpBuilder &builder, .getResult(1); auto sizeRight = builder.create( loc, indexType, shapeRight); - Value axisSize = builder.create(loc, input, axisValue); // Compute flat input shape as a 3-element 1D tensor + auto axisSizeValue = builder.create(loc, axisSize); auto flatShapeType = shape::getExtentTensorType(context, 3); auto flatInputShape = builder.create( - loc, flatShapeType, ValueRange{sizeLeft, axisSize, sizeRight}); + loc, flatShapeType, ValueRange{sizeLeft, axisSizeValue, sizeRight}); // Reshape input to 3D tensor auto inputType = cast(input.getType()); auto elementType = inputType.getElementType(); - SmallVector flatInputDims(3, ShapedType::kDynamic); - auto flatInputType = RankedTensorType::get(flatInputDims, elementType); + auto flatInputType = RankedTensorType::get( + {ShapedType::kDynamic, axisSize, ShapedType::kDynamic}, elementType); auto flatInput = builder.create( loc, flatInputType, input, flatInputShape); return std::make_pair(flatInput, inputShape); } -Value restoreUnrankedTensor(OpBuilder &builder, Location loc, Value input, - Value shape) { - auto inputType = cast(input.getType()); +Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input, + Value 
inputShape) { + auto inputType = cast(input.getType()); auto elementType = inputType.getElementType(); auto unrankedType = UnrankedTensorType::get(elementType); - return builder.create(loc, unrankedType, input, shape); + return builder.create(loc, unrankedType, input, inputShape); } -Value materializeScales(OpBuilder &builder, Location loc, - UniformQuantizedPerAxisType quantizedType) { +Value materializePerChannelScales(OpBuilder &builder, Location loc, + UniformQuantizedPerAxisType quantizedType) { auto scales = quantizedType.getScales(); auto expressedType = quantizedType.getExpressedType(); auto scaleAttrs = llvm::map_to_vector(scales, [&](double scale) -> Attribute { @@ -189,52 +181,100 @@ Value materializeScales(OpBuilder &builder, Location loc, return builder.create(loc, tensorType, scalesAttr); } -Value materializeZeroPoints(OpBuilder &builder, Location loc, +Value materializePerChannelZeroPoints(OpBuilder &builder, Location loc, UniformQuantizedPerAxisType quantizedType) { auto zeroPoints = quantizedType.getZeroPoints(); - auto expressedType = quantizedType.getExpressedType(); - auto zeroPointAttrs = llvm::map_to_vector(zeroPoints, [&](int64_t zeroPoint) -> Attribute { - return builder.getFloatAttr(expressedType, static_cast(zeroPoint)); - }); - auto tensorType = RankedTensorType::get({(int64_t) zeroPoints.size()}, expressedType); + auto storageType = quantizedType.getStorageType(); + auto zeroPointAttrs = llvm::map_to_vector( + zeroPoints, + [&](int64_t zeroPoint) -> Attribute { + return builder.getIntegerAttr(storageType, zeroPoint); + }); + auto tensorType = + RankedTensorType::get({(int64_t)zeroPoints.size()}, storageType); auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs); return builder.create(loc, tensorType, zeroPointsAttr); } -Value quantizeValue(OpBuilder &builder, Location loc, Value input, - ArrayRef inputShape, Value scale, - Value zeroPoint, QuantizedType quantizedType) { - auto inputType = input.getType(); - auto storageType = cast(quantizedType.getStorageType()); - auto storageContainerType = cloneContainerType(inputType, storageType); - - auto scaledValue = builder.create(loc, input, scale); - auto storedValueAsExpressedType = builder.create(loc, scaledValue, zeroPoint); +Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input, + ArrayRef inputShape, + QuantizedType quantizedType) { + // If quantized type does not narrow down the storage type range, there is + // nothing to do. 
+ if (!quantizedType.hasStorageTypeBounds()) + return input; - Value storedValue; + // Materialize bounds + auto inputType = input.getType(); + auto storageType = quantizedType.getStorageType(); + auto storageMinScalar = builder.create( + loc, quantizedType.getStorageTypeMin(), storageType); + auto storageMaxScalar = builder.create( + loc, quantizedType.getStorageTypeMax(), storageType); + auto storageMin = getScalarOrTensorConstant(builder, loc, storageMinScalar, + inputType, inputShape); + auto storageMax = getScalarOrTensorConstant(builder, loc, storageMaxScalar, + inputType, inputShape); + + // Clamp if (quantizedType.isSigned()) { - storedValue = builder.create( - loc, storageContainerType, storedValueAsExpressedType); + input = builder.create(loc, input, storageMin); + input = builder.create(loc, input, storageMax); } else { - storedValue = builder.create( - loc, storageContainerType, storedValueAsExpressedType); + input = builder.create(loc, input, storageMin); + input = builder.create(loc, input, storageMax); } + return input; +} - // Clamp stored value if needed - if (quantizedType.hasStorageTypeBounds()) { - auto storageMinAttr = builder.getIntegerAttr(storageType, quantizedType.getStorageTypeMin()); - auto storageMaxAttr = builder.getIntegerAttr(storageType, quantizedType.getStorageTypeMax()); - auto storageMin = getContainerConstant(builder, loc, storageMinAttr, inputType, inputShape); - auto storageMax = getContainerConstant(builder, loc, storageMaxAttr, inputType, inputShape); - if (quantizedType.isSigned()) { - storedValue = builder.create(loc, storedValue, storageMin); - storedValue = builder.create(loc, storedValue, storageMax); - } else { - storedValue = builder.create(loc, storedValue, storageMin); - storedValue = builder.create(loc, storedValue, storageMax); - } - } +Value convertFloatToInteger(OpBuilder &builder, Location loc, Value input, + Type resultType, bool isSigned) { + if (isSigned) + return builder.create(loc, resultType, input); + return builder.create(loc, resultType, input); +} + +Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input, + Type resultType, bool isSigned) { + if (isSigned) + return builder.create(loc, resultType, input); + return builder.create(loc, resultType, input); +} +// Quantize a floating-point input using the given scale, input shape, and +// storage type bounds in the given quantized type. 
+Value quantizeScalarOrTensor(OpBuilder &builder, Location loc, Value input, + ArrayRef inputShape, Value scale, + Value zeroPoint, QuantizedType quantizedType) { + // Convert scale and zero point to tensors if necessary + auto inputType = input.getType(); + scale = getScalarOrTensorConstant( + builder, loc, scale, inputType, inputShape); + zeroPoint = getScalarOrTensorConstant( + builder, loc, zeroPoint, inputType, inputShape); + + // Convert zero point from storage to expressed type + auto expressedScalarOrTensorType = + getScalarOrTensorType(quantizedType.getExpressedType(), inputType); + zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, + expressedScalarOrTensorType, + quantizedType.isSigned()); + + // Scale input and add zero point + auto scaledValue = builder.create(loc, input, scale); + auto storedValueAsExpressedType = + builder.create(loc, scaledValue, zeroPoint); + + // Convert to storage type + auto storageScalarOrTensorType = + getScalarOrTensorType(quantizedType.getStorageType(), inputType); + auto storedValue = convertFloatToInteger( + builder, loc, storedValueAsExpressedType, storageScalarOrTensorType, + quantizedType.isSigned()); + + // Clamp stored value it if the storage type is bound + storedValue = + clampScalarOrTensor(builder, loc, storedValue, inputShape, quantizedType); return storedValue; } @@ -243,18 +283,21 @@ class QuantizeCastOpConversion : public OpConversionPattern(quantizedType.getExpressedType()); - // Create scale and zero point constants - auto inputShape = getContainerShape(builder, loc, input); - auto scaleAttr = builder.getFloatAttr(expressedType, quantizedType.getScale()); - auto scale = getContainerConstant(builder, loc, scaleAttr, inputType, inputShape); - auto zeroPointAttr = builder.getFloatAttr(expressedType, quantizedType.getZeroPoint()); - auto zeroPoint = getContainerConstant(builder, loc, zeroPointAttr, inputType, inputShape); - - return quantizeValue(builder, loc, input, inputShape, scale, zeroPoint, - quantizedType); + auto expressedType = quantizedType.getExpressedType(); + auto storageType = quantizedType.getStorageType(); + auto scaleAttr = + builder.getFloatAttr(expressedType, quantizedType.getScale()); + auto scale = + builder.create(loc, expressedType, scaleAttr); + auto zeroPointAttr = + builder.getIntegerAttr(storageType, quantizedType.getZeroPoint()); + auto zeroPoint = + builder.create(loc, storageType, zeroPointAttr); + + auto inputShape = getScalarOrTensorShape(builder, loc, input); + return quantizeScalarOrTensor(builder, loc, input, inputShape, scale, + zeroPoint, quantizedType); } Value convertPerLayer(OpBuilder &builder, Location loc, Value input, @@ -270,7 +313,7 @@ class QuantizeCastOpConversion : public OpConversionPattern(input.getType()); auto inputRank = inputType.getRank(); - auto scales = materializeScales(builder, loc, quantizedType); - auto zeroPoints = materializeZeroPoints(builder, loc, quantizedType); + auto scales = materializePerChannelScales(builder, loc, quantizedType); + auto zeroPoints = + materializePerChannelZeroPoints(builder, loc, quantizedType); auto storageType = quantizedType.getStorageType(); auto initShape = tensor::getMixedSizes(builder, loc, input); @@ -313,10 +357,10 @@ class QuantizeCastOpConversion : public OpConversionPattern(loc, storedValue); + builder.create(loc, result); }) .getResult(0); @@ -325,13 +369,14 @@ class QuantizeCastOpConversion : public OpConversionPattern(input.getType()); int64_t channelAxis = quantizedType.getQuantizedDimension(); + int64_t channelAxisSize = 
(int64_t) quantizedType.getScales().size(); Value inputShape; if (isUnranked) { - std::tie(input, inputShape) = - flattenUnrankedTensorAroundAxis(builder, loc, input, channelAxis); + std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis( + builder, loc, input, channelAxis, channelAxisSize); channelAxis = 1; } @@ -341,7 +386,7 @@ class QuantizeCastOpConversion : public OpConversionPattern(resultScalarType)) { - storedValue = convertPerLayer(rewriter, loc, input, quantizedType); - } else if (auto quantizedType = dyn_cast(resultScalarType)) { - storedValue = convertPerChannel(rewriter, loc, input, quantizedType); + result = convertPerLayer(rewriter, loc, input, quantizedType); + } else if (auto quantizedType = + dyn_cast(resultScalarType)) { + result = convertPerChannel(rewriter, loc, input, quantizedType); } else { - llvm_unreachable("unexpected quantized type"); + llvm_unreachable("unexpected uniform quantized type"); } // Cast stored value to result quantized value rewriter.replaceOpWithNewOp( - op, op.getResult().getType(), storedValue); + op, op.getResult().getType(), result); return success(); } }; From f032c472134642dc44ab6134d2e2646ca13ca341 Mon Sep 17 00:00:00 2001 From: Rafael Ubal Tena Date: Wed, 17 Jul 2024 11:31:20 -0400 Subject: [PATCH 08/22] Per-layer quantization unit tests --- .../Quant/Transforms/LowerQuantOps.cpp | 50 +++--- mlir/test/Dialect/Quant/lower-quant-ops.mlir | 165 ++++++++++++++++++ 2 files changed, 194 insertions(+), 21 deletions(-) create mode 100644 mlir/test/Dialect/Quant/lower-quant-ops.mlir diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp index 68a4328128292..e6942899cd638 100644 --- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp +++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/Quant/Transforms/Passes.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" @@ -181,8 +182,9 @@ Value materializePerChannelScales(OpBuilder &builder, Location loc, return builder.create(loc, tensorType, scalesAttr); } -Value materializePerChannelZeroPoints(OpBuilder &builder, Location loc, - UniformQuantizedPerAxisType quantizedType) { +Value materializePerChannelZeroPoints( + OpBuilder &builder, Location loc, + UniformQuantizedPerAxisType quantizedType) { auto zeroPoints = quantizedType.getZeroPoints(); auto storageType = quantizedType.getStorageType(); auto zeroPointAttrs = llvm::map_to_vector( @@ -246,36 +248,42 @@ Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input, Value quantizeScalarOrTensor(OpBuilder &builder, Location loc, Value input, ArrayRef inputShape, Value scale, Value zeroPoint, QuantizedType quantizedType) { - // Convert scale and zero point to tensors if necessary + // Convert scale to tensor if necessary auto inputType = input.getType(); scale = getScalarOrTensorConstant( builder, loc, scale, inputType, inputShape); - zeroPoint = getScalarOrTensorConstant( - builder, loc, zeroPoint, inputType, inputShape); - // Convert zero point from storage to expressed type - auto expressedScalarOrTensorType = - getScalarOrTensorType(quantizedType.getExpressedType(), inputType); - zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, - expressedScalarOrTensorType, - quantizedType.isSigned()); - - // Scale input and add zero point + // Scale input auto scaledValue = 
builder.create(loc, input, scale); - auto storedValueAsExpressedType = - builder.create(loc, scaledValue, zeroPoint); - // Convert to storage type + // Skip unnecessary computations if no zero point is given + Value storedValueFloat = scaledValue; + if (matchPattern(zeroPoint, m_NonZero())) { + // Convert zero point to tensor if necessary + zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType, + inputShape); + + // Convert zero point from storage to expressed type + zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, + scale.getType(), + quantizedType.isSigned()); + + // Add zero point to stored value + storedValueFloat = + builder.create(loc, scaledValue, zeroPoint); + } + + // Convert stored value to storage type auto storageScalarOrTensorType = getScalarOrTensorType(quantizedType.getStorageType(), inputType); - auto storedValue = convertFloatToInteger( - builder, loc, storedValueAsExpressedType, storageScalarOrTensorType, + auto storedValueInt = convertFloatToInteger( + builder, loc, storedValueFloat, storageScalarOrTensorType, quantizedType.isSigned()); // Clamp stored value it if the storage type is bound - storedValue = - clampScalarOrTensor(builder, loc, storedValue, inputShape, quantizedType); - return storedValue; + auto storedValueClamped = clampScalarOrTensor(builder, loc, storedValueInt, + inputShape, quantizedType); + return storedValueClamped; } class QuantizeCastOpConversion : public OpConversionPattern { diff --git a/mlir/test/Dialect/Quant/lower-quant-ops.mlir b/mlir/test/Dialect/Quant/lower-quant-ops.mlir new file mode 100644 index 0000000000000..0e151c514eebc --- /dev/null +++ b/mlir/test/Dialect/Quant/lower-quant-ops.mlir @@ -0,0 +1,165 @@ +// RUN: mlir-opt %s --lower-quant-ops --split-input-file | FileCheck %s + +// CHECK-LABEL: @qcast_per_layer_scalar +// CHECK-SAME: %[[ARG_0:.*]]: f32 + +// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8 + +// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE]] : f32 +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32 +// CHECK: %[[STORED:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : f32 +// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED]] : f32 to i8 + +// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_INT]] : i8 to !quant.uniform +// CHECK: return %[[STORED_QUANT]] : !quant.uniform + +!qalias = !quant.uniform +func.func @qcast_per_layer_scalar(%arg0: f32) -> !qalias { + %0 = quant.qcast %arg0 : f32 to !qalias + return %0 : !qalias +} + +// ----- + +// CHECK-LABEL: @qcast_per_layer_scalar_bounds +// CHECK-SAME: %[[ARG_0:.*]]: f32 + +// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 0 : i8 + +// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE]] : f32 +// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[SCALED]] : f32 to i8 + +// CHECK-DAG: %[[C_NEG_5:.*]] = arith.constant -5 : i8 +// CHECK-DAG: %[[C_10:.*]] = arith.constant 10 : i8 +// CHECK: %[[STORED_CLAMPED_TEMP:.*]] = arith.maxsi %[[STORED_INT]], %[[C_NEG_5]] : i8 +// CHECK: %[[STORED_CLAMPED:.*]] = arith.minsi %[[STORED_CLAMPED_TEMP]], %[[C_10]] : i8 + +// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_CLAMPED]] : i8 to !quant.uniform:f32, 2.000000e+00> +// CHECK: return %[[STORED_QUANT]] : !quant.uniform:f32, 2.000000e+00> + +!qalias = !quant.uniform:f32, 2.0> +func.func @qcast_per_layer_scalar_bounds(%arg0: f32) -> !qalias { + %0 = quant.qcast %arg0 : f32 to 
!qalias + return %0 : !qalias +} + +// ----- + +// CHECK-LABEL: @qcast_per_layer_scalar_unsigned_bounds +// CHECK-SAME: %[[ARG_0:.*]]: f32 + +// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 0 : i8 + +// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE]] : f32 +// CHECK: %[[STORED_INT:.*]] = arith.fptoui %[[SCALED]] : f32 to i8 + +// CHECK-DAG: %[[C_2:.*]] = arith.constant 2 : i8 +// CHECK-DAG: %[[C_10:.*]] = arith.constant 10 : i8 +// CHECK: %[[STORED_CLAMPED_TEMP:.*]] = arith.maxui %[[STORED_INT]], %[[C_2]] : i8 +// CHECK: %[[STORED_CLAMPED:.*]] = arith.minui %[[STORED_CLAMPED_TEMP]], %[[C_10]] : i8 + +// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_CLAMPED]] : i8 to !quant.uniform:f32, 2.000000e+00> +// CHECK: return %[[STORED_QUANT]] : !quant.uniform:f32, 2.000000e+00> + +!qalias = !quant.uniform:f32, 2.0> +func.func @qcast_per_layer_scalar_unsigned_bounds(%arg0: f32) -> !qalias { + %0 = quant.qcast %arg0 : f32 to !qalias + return %0 : !qalias +} + +// ----- + +// CHECK-LABEL: @qcast_per_layer_ranked +// CHECK-SAME: %[[ARG_0:.*]]: tensor<3x?x5xf32> + +// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8 +// CHECK-DAG: %[[C_1:.*]] = arith.constant 1 : index + +// CHECK: %[[DIM_1:.*]] = tensor.dim %[[ARG_0]], %[[C_1]] : tensor<3x?x5xf32> +// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]]{{\[}}%[[DIM_1]]] : tensor<3x?x5xf32> +// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE_TENSOR]] : tensor<3x?x5xf32> + +// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]]{{\[}}%[[DIM_1]]] : tensor<3x?x5xi8> +// CHECK: %[[ZERO_POINT_TENSOR_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor<3x?x5xi8> to tensor<3x?x5xf32> +// CHECK: %[[STORED:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_TENSOR_FLOAT]] : tensor<3x?x5xf32> +// CHECK: %[[STORED_FLOAT:.*]] = arith.fptosi %[[STORED]] : tensor<3x?x5xf32> to tensor<3x?x5xi8> + +// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_FLOAT]] : tensor<3x?x5xi8> to tensor<3x?x5x!quant.uniform> +// CHECK: return %[[STORED_QUANT]] : tensor<3x?x5x!quant.uniform> + +!qalias = !quant.uniform +func.func @qcast_per_layer_ranked(%arg0: tensor<3x?x5xf32>) -> tensor<3x?x5x!qalias> { + %0 = quant.qcast %arg0 : tensor<3x?x5xf32> to tensor<3x?x5x!qalias> + return %0 : tensor<3x?x5x!qalias> +} + +// ----- + +// CHECK-LABEL: @qcast_per_layer_ranked_bounds +// CHECK-SAME: %[[ARG_0:.*]]: tensor<3x5xf32> + +// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8 + +// CHECK: %[[SCALE_SPLAT:.*]] = tensor.splat %[[SCALE]] : tensor<3x5xf32> +// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE_SPLAT]] : tensor<3x5xf32> + +// CHECK: %[[ZERO_POINT_SPLAT:.*]] = tensor.splat %[[ZERO_POINT]] : tensor<3x5xi8> +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_SPLAT]] : tensor<3x5xi8> to tensor<3x5xf32> + +// CHECK: %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : tensor<3x5xf32> +// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : tensor<3x5xf32> to tensor<3x5xi8> + +// CHECK-DAG: %[[C_NEG_8:.*]] = arith.constant -8 : i8 +// CHECK-DAG: %[[C_7:.*]] = arith.constant 7 : i8 +// CHECK-DAG: %[[SPLAT_NEG_8:.*]] = tensor.splat %[[C_NEG_8]] : tensor<3x5xi8> +// CHECK-DAG: %[[SPLAT_7:.*]] = tensor.splat %[[C_7]] : tensor<3x5xi8> +// CHECK: %[[STORED_CLAMPED_TEMP:.*]] = arith.maxsi %[[STORED_INT]], 
%[[SPLAT_NEG_8]] : tensor<3x5xi8> +// CHECK: %[[STORED_CLAMPED:.*]] = arith.minsi %[[STORED_CLAMPED_TEMP]], %[[SPLAT_7]] : tensor<3x5xi8> + +// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_CLAMPED]] : tensor<3x5xi8> to tensor<3x5x!quant.uniform:f32, 2.000000e+00:10>> +// CHECK: return %[[STORED_QUANT]] : tensor<3x5x!quant.uniform:f32, 2.000000e+00:10>> + +!qalias = !quant.uniform:f32, 2.0:10> +func.func @qcast_per_layer_ranked_bounds(%arg0: tensor<3x5xf32>) -> tensor<3x5x!qalias> { + %0 = quant.qcast %arg0 : tensor<3x5xf32> to tensor<3x5x!qalias> + return %0 : tensor<3x5x!qalias> +} + +// ----- + +// CHECK-LABEL: @qcast_per_layer_unranked +// CHECK-SAME: %[[ARG_0:.*]]: tensor<*xf32> + +// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32> -> tensor +// CHECK: %[[SIZE:.*]] = shape.num_elements %[[SHAPE]] : tensor -> index +// CHECK: %[[SIZE_TENSOR:.*]] = tensor.from_elements %[[SIZE]] : tensor<1xindex> +// CHECK: %[[RANKED_INPUT:.*]] = tensor.reshape %[[ARG_0]](%[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor + +// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8 +// CHECK-DAG: %[[C_0:.*]] = arith.constant 0 : index + +// CHECK: %[[DIM_0:.*]] = tensor.dim %[[RANKED_INPUT]], %[[C_0]] : tensor +// CHECK: %[[SCALE_SPLAT:.*]] = tensor.splat %[[SCALE]]{{\[}}%[[DIM_0]]] : tensor +// CHECK: %[[SCALED:.*]] = arith.divf %[[RANKED_INPUT]], %[[SCALE_SPLAT]] : tensor + +// CHECK: %[[ZERO_POINT_SPLAT:.*]] = tensor.splat %[[ZERO_POINT]]{{\[}}%[[DIM_0]]] : tensor +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_SPLAT]] : tensor to tensor +// CHECK: %[[STORED:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : tensor +// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED]] : tensor to tensor + +// CHECK: %[[STORED_UNRANKED:.*]] = tensor.reshape %[[STORED_INT]](%[[SHAPE]]) : (tensor, tensor) -> tensor<*xi8> +// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_UNRANKED]] : tensor<*xi8> to tensor<*x!quant.uniform> +// CHECK: return %[[STORED_QUANT]] : tensor<*x!quant.uniform> + +!qalias = !quant.uniform +func.func @qcast_per_layer_unranked(%arg0: tensor<*xf32>) -> tensor<*x!qalias> { + %0 = quant.qcast %arg0 : tensor<*xf32> to tensor<*x!qalias> + return %0 : tensor<*x!qalias> +} + From 3a913979546c071eaae5e681d4313e5abf18e523 Mon Sep 17 00:00:00 2001 From: Rafael Ubal Tena Date: Wed, 17 Jul 2024 12:45:58 -0400 Subject: [PATCH 09/22] Completed unit tests for 'quant.qcast' --- .../Quant/Transforms/LowerQuantOps.cpp | 2 +- mlir/test/Dialect/Quant/lower-quant-ops.mlir | 117 ++++++++++++++++++ 2 files changed, 118 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp index e6942899cd638..1ffbd7032ae58 100644 --- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp +++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp @@ -258,7 +258,7 @@ Value quantizeScalarOrTensor(OpBuilder &builder, Location loc, Value input, // Skip unnecessary computations if no zero point is given Value storedValueFloat = scaledValue; - if (matchPattern(zeroPoint, m_NonZero())) { + if (!matchPattern(zeroPoint, m_Zero())) { // Convert zero point to tensor if necessary zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType, inputShape); diff --git a/mlir/test/Dialect/Quant/lower-quant-ops.mlir b/mlir/test/Dialect/Quant/lower-quant-ops.mlir index 0e151c514eebc..1aa14c615c6e5 100644 --- 
a/mlir/test/Dialect/Quant/lower-quant-ops.mlir +++ b/mlir/test/Dialect/Quant/lower-quant-ops.mlir @@ -163,3 +163,120 @@ func.func @qcast_per_layer_unranked(%arg0: tensor<*xf32>) -> tensor<*x!qalias> { return %0 : tensor<*x!qalias> } +// ----- + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1)> + +// CHECK-LABEL: @qcast_per_channel_ranked +// CHECK-SAME: %[[ARG_0:.*]]: tensor<4x?x?x5xf32> + +// CHECK: %[[SCALES:.*]] = arith.constant dense<[2.000000e+00, 3.000000e+00]> : tensor<2xf32> +// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<[10, 20]> : tensor<2xi8> + +// CHECK-DAG: %[[C_1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[DIM_1:.*]] = tensor.dim %[[ARG_0]], %[[C_1]] : tensor<4x?x?x5xf32> +// CHECK-DAG: %[[C_2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[DIM_2:.*]] = tensor.dim %[[ARG_0]], %[[C_2]] : tensor<4x?x?x5xf32> +// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM_1]], %[[DIM_2]]) : tensor<4x?x?x5xi8> + +// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG_0]], %[[SCALES]], %[[ZERO_POINTS]] : tensor<4x?x?x5xf32>, tensor<2xf32>, tensor<2xi8>) outs(%[[INIT]] : tensor<4x?x?x5xi8>) { +// CHECK: ^bb0(%[[IN:.*]]: f32, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: i8): +// CHECK: %[[SCALED:.*]] = arith.divf %[[IN]], %[[SCALE]] : f32 +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32 +// CHECK: %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : f32 +// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : f32 to i8 +// CHECK: linalg.yield %[[STORED_INT]] : i8 +// CHECK: } -> tensor<4x?x?x5xi8> + +// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[GENERIC]] : tensor<4x?x?x5xi8> to tensor<4x?x?x5x!quant.uniform> +// CHECK: return %[[STORED_QUANT]] : tensor<4x?x?x5x!quant.uniform> + +!qalias = !quant.uniform +func.func @qcast_per_channel_ranked(%arg0: tensor<4x?x?x5xf32>) -> tensor<4x?x?x5x!qalias> { + %0 = "quant.qcast"(%arg0) : (tensor<4x?x?x5xf32>) -> tensor<4x?x?x5x!qalias> + return %0 : tensor<4x?x?x5x!qalias> +} + +// ----- + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1)> + +// CHECK-LABEL: @qcast_per_channel_ranked_bounds +// CHECK-SAME: %[[ARG_0:.*]]: tensor<4x2x5xf32> + +// CHECK: %[[SCALES:.*]] = arith.constant dense<[2.000000e+00, 3.000000e+00]> : tensor<2xf32> +// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<0> : tensor<2xi8> + +// CHECK: %[[INIT:.*]] = tensor.empty() : tensor<4x2x5xi8> +// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[ARG_0]], %[[SCALES]], %[[ZERO_POINTS]] : tensor<4x2x5xf32>, tensor<2xf32>, tensor<2xi8>) outs(%[[INIT]] : tensor<4x2x5xi8>) { +// CHECK: ^bb0(%[[IN:.*]]: f32, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: i8): +// CHECK: %[[SCALED:.*]] = arith.divf %[[IN]], %[[SCALE]] : f32 +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32 +// CHECK: %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : f32 +// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : f32 to i8 +// CHECK: %[[C_NEG_8:.*]] = arith.constant -8 : i8 +// CHECK: %[[C_7:.*]] = arith.constant 7 : i8 +// CHECK: 
%[[STORED_CLAMPED_TEMP:.*]] = arith.maxsi %[[STORED_INT]], %[[C_NEG_8]] : i8 +// CHECK: %[[STORED_CLAMPED:.*]] = arith.minsi %[[STORED_CLAMPED_TEMP]], %[[C_7]] : i8 +// CHECK: linalg.yield %[[STORED_CLAMPED]] : i8 +// CHECK: } -> tensor<4x2x5xi8> + +// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[GENERIC]] : tensor<4x2x5xi8> to tensor<4x2x5x!quant.uniform:f32:1, {2.000000e+00,3.000000e+00}>> +// CHECK: return %[[STORED_QUANT]] : tensor<4x2x5x!quant.uniform:f32:1, {2.000000e+00,3.000000e+00}>> + +!qalias = !quant.uniform:f32:1, {2.0, 3.0}> +func.func @qcast_per_channel_ranked_bounds(%arg0: tensor<4x2x5xf32>) -> tensor<4x2x5x!qalias> { + %0 = "quant.qcast"(%arg0) : (tensor<4x2x5xf32>) -> tensor<4x2x5x!qalias> + return %0 : tensor<4x2x5x!qalias> +} + +// ----- + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1)> + +// CHECK-LABEL: @qcast_per_channel_unranked +// CHECK-SAME: %[[ARG_0:.*]]: tensor<*xf32> + +// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32> -> tensor +// CHECK: %[[CHANNEL_AXIS:.*]] = arith.constant 2 : index +// CHECK: %[[CHANNEL_AXIS_NEXT:.*]] = arith.constant 3 : index +// CHECK: %[[SHAPE_LEFT:.*]], %[[DISCARDED_0:.*]] = "shape.split_at"(%[[SHAPE]], %[[CHANNEL_AXIS]]) : (tensor, index) -> (tensor, tensor) +// CHECK: %[[SIZE_LEFT:.*]] = shape.num_elements %[[SHAPE_LEFT]] : tensor -> index +// CHECK: %[[DISCARDED_1:.*]], %[[SHAPE_RIGHT:.*]] = "shape.split_at"(%[[SHAPE]], %[[CHANNEL_AXIS_NEXT]]) : (tensor, index) -> (tensor, tensor) +// CHECK: %[[SIZE_RIGHT:.*]] = shape.num_elements %[[SHAPE_RIGHT]] : tensor -> index + +// CHECK: %[[CHANNEL_AXIS_SIZE:.*]] = arith.constant 3 : index +// CHECK: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[SIZE_LEFT]], %[[CHANNEL_AXIS_SIZE]], %[[SIZE_RIGHT]] : tensor<3xindex> +// CHECK: %[[FLAT_INPUT:.*]] = tensor.reshape %[[ARG_0]](%[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor + +// CHECK: %[[SCALES:.*]] = arith.constant dense<[2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<3xf32> +// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<[10, 20, 30]> : tensor<3xi8> + +// CHECK: %[[C_0:.*]] = arith.constant 0 : index +// CHECK: %[[DIM_0:.*]] = tensor.dim %[[FLAT_INPUT]], %[[C_0]] : tensor +// CHECK: %[[C_2:.*]] = arith.constant 2 : index +// CHECK: %[[DIM_2:.*]] = tensor.dim %[[FLAT_INPUT]], %[[C_2]] : tensor +// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM_0]], %[[DIM_2]]) : tensor + +// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[FLAT_INPUT]], %[[SCALES]], %[[ZERO_POINTS]] : tensor, tensor<3xf32>, tensor<3xi8>) outs(%[[INIT]] : tensor) { +// CHECK: ^bb0(%[[IN:.*]]: f32, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: i8): +// CHECK: %[[SCALED:.*]] = arith.divf %[[IN]], %[[SCALE]] : f32 +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32 +// CHECK: %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : f32 +// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : f32 to i8 +// CHECK: linalg.yield %[[STORED_INT]] : i8 +// CHECK: } -> tensor + +// CHECK: %[[STORED_UNRANKED:.*]] = tensor.reshape %[[GENERIC]](%[[SHAPE]]) : (tensor, tensor) -> tensor<*xi8> +// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_UNRANKED]] : tensor<*xi8> to tensor<*x!quant.uniform> +// CHECK: return %[[STORED_QUANT]] : tensor<*x!quant.uniform> + +!qalias = !quant.uniform 
+func.func @qcast_per_channel_unranked(%arg0: tensor<*xf32>) -> tensor<*x!qalias> { + %0 = "quant.qcast"(%arg0) : (tensor<*xf32>) -> tensor<*x!qalias> + return %0 : tensor<*x!qalias> +} + From b50b2e12ebc9d981ac0d39b3c69d5c0f5f3c0afe Mon Sep 17 00:00:00 2001 From: Rafael Ubal Tena Date: Mon, 22 Jul 2024 09:34:57 -0400 Subject: [PATCH 10/22] Unit test for 0D tensor --- mlir/test/Dialect/Quant/lower-quant-ops.mlir | 29 ++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/mlir/test/Dialect/Quant/lower-quant-ops.mlir b/mlir/test/Dialect/Quant/lower-quant-ops.mlir index 1aa14c615c6e5..ad3be870916dc 100644 --- a/mlir/test/Dialect/Quant/lower-quant-ops.mlir +++ b/mlir/test/Dialect/Quant/lower-quant-ops.mlir @@ -86,9 +86,9 @@ func.func @qcast_per_layer_scalar_unsigned_bounds(%arg0: f32) -> !qalias { // CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]]{{\[}}%[[DIM_1]]] : tensor<3x?x5xi8> // CHECK: %[[ZERO_POINT_TENSOR_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor<3x?x5xi8> to tensor<3x?x5xf32> // CHECK: %[[STORED:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_TENSOR_FLOAT]] : tensor<3x?x5xf32> -// CHECK: %[[STORED_FLOAT:.*]] = arith.fptosi %[[STORED]] : tensor<3x?x5xf32> to tensor<3x?x5xi8> +// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED]] : tensor<3x?x5xf32> to tensor<3x?x5xi8> -// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_FLOAT]] : tensor<3x?x5xi8> to tensor<3x?x5x!quant.uniform> +// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_INT]] : tensor<3x?x5xi8> to tensor<3x?x5x!quant.uniform> // CHECK: return %[[STORED_QUANT]] : tensor<3x?x5x!quant.uniform> !qalias = !quant.uniform @@ -99,6 +99,31 @@ func.func @qcast_per_layer_ranked(%arg0: tensor<3x?x5xf32>) -> tensor<3x?x5x!qal // ----- +// CHECK-LABEL: @qcast_per_layer_0d +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8 + +// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]] : tensor +// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE_TENSOR]] : tensor + +// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]] : tensor +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor to tensor +// CHECK: %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : tensor +// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : tensor to tensor + +// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_INT]] : tensor to tensor> +// CHECK: return %[[STORED_QUANT]] : tensor> + +!qalias = !quant.uniform +func.func @qcast_per_layer_0d(%arg0: tensor) -> tensor { + %0 = quant.qcast %arg0 : tensor to tensor + return %0 : tensor +} + +// ----- + // CHECK-LABEL: @qcast_per_layer_ranked_bounds // CHECK-SAME: %[[ARG_0:.*]]: tensor<3x5xf32> From 9fc4be12215ad3bb9940203bda6d0114b05fa9c9 Mon Sep 17 00:00:00 2001 From: Rafael Ubal Tena Date: Mon, 22 Jul 2024 12:36:41 -0400 Subject: [PATCH 11/22] Support for 'quant.dcast' and unit tests --- .../Quant/Transforms/LowerQuantOps.cpp | 217 ++++++++++++++++-- mlir/test/Dialect/Quant/lower-quant-ops.mlir | 95 ++++++-- 2 files changed, 264 insertions(+), 48 deletions(-) diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp index 1ffbd7032ae58..1b2d36dc672cf 100644 --- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp +++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp @@ -30,26 +30,6 @@ namespace quant { 
namespace { -//===----------------------------------------------------------------------===// -// DequantizeCastOp -//===----------------------------------------------------------------------===// - -class DequantizeCastOpConversion : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(quant::DequantizeCastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - return success(); - } -}; - - -//===----------------------------------------------------------------------===// -// QuantizeCastOp -//===----------------------------------------------------------------------===// - // If 'inputType' is a tensor, return its element type. If it is a scalar, // return it as is. Type getScalarType(Type inputType) { @@ -243,8 +223,9 @@ Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input, return builder.create(loc, resultType, input); } -// Quantize a floating-point input using the given scale, input shape, and -// storage type bounds in the given quantized type. +// Quantize a floating-point input using the given input shape, scale, and +// zero point. The stored value is clamped using the storage bounds encoded in +// the given quantized type. Value quantizeScalarOrTensor(OpBuilder &builder, Location loc, Value input, ArrayRef inputShape, Value scale, Value zeroPoint, QuantizedType quantizedType) { @@ -286,6 +267,196 @@ Value quantizeScalarOrTensor(OpBuilder &builder, Location loc, Value input, return storedValueClamped; } +Value dequantizeScalarOrTensor(OpBuilder &builder, Location loc, Value input, + ArrayRef inputShape, Value scale, + Value zeroPoint, QuantizedType quantizedType) { + // Convert scale to tensor if necessary + auto inputType = input.getType(); + scale = getScalarOrTensorConstant( + builder, loc, scale, inputType, inputShape); + + // Convert stored value to float + auto result = convertIntegerToFloat( + builder, loc, input, scale.getType(), quantizedType.isSigned()); + + // Skip unnecessary computations if no zero point is given + if (!matchPattern(zeroPoint, m_Zero())) { + // Convert zero point to tensor if necessary + zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType, + inputShape); + + // Convert zero point from storage to expressed type + zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, + scale.getType(), + quantizedType.isSigned()); + + // Subtract zero point to stored value + result = builder.create(loc, result, zeroPoint); + } + + // Multiply by scale + result = builder.create(loc, result, scale); + return result; +} + + +//===----------------------------------------------------------------------===// +// DequantizeCastOp +//===----------------------------------------------------------------------===// + +class DequantizeCastOpConversion : public OpConversionPattern { + + Value convertPerLayerRanked(OpBuilder &builder, Location loc, Value input, + UniformQuantizedType quantizedType) const { + + // Create scale and zero point constants + auto expressedType = quantizedType.getExpressedType(); + auto storageType = quantizedType.getStorageType(); + auto scaleAttr = + builder.getFloatAttr(expressedType, quantizedType.getScale()); + auto scale = + builder.create(loc, expressedType, scaleAttr); + auto zeroPointAttr = + builder.getIntegerAttr(storageType, quantizedType.getZeroPoint()); + auto zeroPoint = + builder.create(loc, storageType, zeroPointAttr); + + auto inputShape = getScalarOrTensorShape(builder, loc, input); + return 
dequantizeScalarOrTensor(builder, loc, input, inputShape, scale, + zeroPoint, quantizedType); + } + + Value convertPerLayer(OpBuilder &builder, Location loc, Value input, + UniformQuantizedType quantizedType) const { + // Flatten input if unranked + bool isUnranked = isa(input.getType()); + Value inputShape; + if (isUnranked) + std::tie(input, inputShape) = flattenUnrankedTensor(builder, loc, input); + + // Process ranked tensor + auto result = convertPerLayerRanked(builder, loc, input, quantizedType); + + // Restore original shape if unranked + if (isUnranked) + result = restoreUnrankedTensorShape(builder, loc, result, inputShape); + + return result; + } + + Value convertPerChannelRanked(OpBuilder &builder, Location loc, Value input, + UniformQuantizedPerAxisType quantizedType, + int64_t channelAxis) const { + auto *context = builder.getContext(); + + auto inputType = cast(input.getType()); + auto inputRank = inputType.getRank(); + + auto scales = materializePerChannelScales(builder, loc, quantizedType); + auto zeroPoints = + materializePerChannelZeroPoints(builder, loc, quantizedType); + + auto storageType = quantizedType.getStorageType(); + auto initShape = tensor::getMixedSizes(builder, loc, input); + Value init = builder.create(loc, initShape, storageType); + + SmallVector iteratorTypes( + inputRank, utils::IteratorType::parallel); + auto channelAxisAffineMap = AffineMap::get( + inputRank, 0, builder.getAffineDimExpr(channelAxis), context); + SmallVector indexingMaps{ + builder.getMultiDimIdentityMap(inputRank), + channelAxisAffineMap, + channelAxisAffineMap, + builder.getMultiDimIdentityMap(inputRank) + }; + auto storedValue = builder.create( + loc, + init.getType(), // resultType + ValueRange{input, scales, zeroPoints}, // inputs + ValueRange{init}, // outputs + indexingMaps, + iteratorTypes, + [&](OpBuilder& builder, Location loc, ValueRange args) { + assert(args.size() == 4); + auto expressedValue = args[0]; + auto scale = args[1]; + auto zeroPoint = args[2]; + + auto result = quantizeScalarOrTensor(builder, loc, expressedValue, {}, + scale, zeroPoint, quantizedType); + + builder.create(loc, result); + }) + .getResult(0); + + return storedValue; + } + + Value convertPerChannel(OpBuilder &builder, Location loc, Value input, + UniformQuantizedPerAxisType quantizedType) const { + // Flatten unranked tensor into a 3D ranked tensor if necessary + bool isUnranked = isa(input.getType()); + int64_t channelAxis = quantizedType.getQuantizedDimension(); + int64_t channelAxisSize = (int64_t) quantizedType.getScales().size(); + Value inputShape; + if (isUnranked) { + std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis( + builder, loc, input, channelAxis, channelAxisSize); + channelAxis = 1; + } + + // Work on a ranked tensor + auto result = convertPerChannelRanked(builder, loc, input, quantizedType, + channelAxis); + + // Restore original tensor shape if unranked + if (isUnranked) + result = restoreUnrankedTensorShape(builder, loc, result, inputShape); + + return result; + } + +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(quant::DequantizeCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto input = op.getInput(); + auto quantizedType = + cast(getScalarType(op.getInput().getType())); + + // Convert quantized input to storage type + auto storageScalarOrTensorType = + getScalarOrTensorType(quantizedType.getStorageType(), input.getType()); + input = rewriter.create( + loc, 
storageScalarOrTensorType, input); + + // Flatten unranked tensor input + Value result; + if (auto uniformQuantizedType = + dyn_cast(quantizedType)) { + result = convertPerLayer(rewriter, loc, input, uniformQuantizedType); + } else if (auto uniformQuantizedPerAxisType = + dyn_cast(quantizedType)) { + result = + convertPerChannel(rewriter, loc, input, uniformQuantizedPerAxisType); + } else { + llvm_unreachable("unexpected quantized type"); + } + + rewriter.replaceOp(op, result); + return success(); + } +}; + + +//===----------------------------------------------------------------------===// +// QuantizeCastOp +//===----------------------------------------------------------------------===// + class QuantizeCastOpConversion : public OpConversionPattern { Value convertPerLayerRanked(OpBuilder &builder, Location loc, Value input, @@ -417,7 +588,7 @@ class QuantizeCastOpConversion : public OpConversionPattern(resultScalarType)) { result = convertPerChannel(rewriter, loc, input, quantizedType); } else { - llvm_unreachable("unexpected uniform quantized type"); + llvm_unreachable("unexpected quantized type"); } // Cast stored value to result quantized value diff --git a/mlir/test/Dialect/Quant/lower-quant-ops.mlir b/mlir/test/Dialect/Quant/lower-quant-ops.mlir index ad3be870916dc..1030d5b20620b 100644 --- a/mlir/test/Dialect/Quant/lower-quant-ops.mlir +++ b/mlir/test/Dialect/Quant/lower-quant-ops.mlir @@ -72,6 +72,31 @@ func.func @qcast_per_layer_scalar_unsigned_bounds(%arg0: f32) -> !qalias { // ----- +// CHECK-LABEL: @qcast_per_layer_0d +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8 + +// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]] : tensor +// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE_TENSOR]] : tensor + +// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]] : tensor +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor to tensor +// CHECK: %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : tensor +// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : tensor to tensor + +// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_INT]] : tensor to tensor> +// CHECK: return %[[STORED_QUANT]] : tensor> + +!qalias = !quant.uniform +func.func @qcast_per_layer_0d(%arg0: tensor) -> tensor { + %0 = quant.qcast %arg0 : tensor to tensor + return %0 : tensor +} + +// ----- + // CHECK-LABEL: @qcast_per_layer_ranked // CHECK-SAME: %[[ARG_0:.*]]: tensor<3x?x5xf32> @@ -99,31 +124,6 @@ func.func @qcast_per_layer_ranked(%arg0: tensor<3x?x5xf32>) -> tensor<3x?x5x!qal // ----- -// CHECK-LABEL: @qcast_per_layer_0d -// CHECK-SAME: %[[ARG_0:.*]]: tensor - -// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32 -// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8 - -// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]] : tensor -// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE_TENSOR]] : tensor - -// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]] : tensor -// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor to tensor -// CHECK: %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : tensor -// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : tensor to tensor - -// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_INT]] : tensor to tensor> -// CHECK: return %[[STORED_QUANT]] : tensor> - -!qalias = !quant.uniform 
-func.func @qcast_per_layer_0d(%arg0: tensor) -> tensor { - %0 = quant.qcast %arg0 : tensor to tensor - return %0 : tensor -} - -// ----- - // CHECK-LABEL: @qcast_per_layer_ranked_bounds // CHECK-SAME: %[[ARG_0:.*]]: tensor<3x5xf32> @@ -305,3 +305,48 @@ func.func @qcast_per_channel_unranked(%arg0: tensor<*xf32>) -> tensor<*x!qalias> return %0 : tensor<*x!qalias> } +// ----- + +// CHECK-LABEL: @dcast_per_layer_scalar +// CHECK-SAME: %[[ARG_0:.*]]: !quant.uniform + +// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : !quant.uniform to i8 + +// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8 +// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : i8 to f32 +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32 + +// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32 +// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32 +// CHECK: return %[[EXPRESSED]] : f32 + +!qalias = !quant.uniform +func.func @dcast_per_layer_scalar(%arg0: !qalias) -> f32 { + %0 = quant.dcast %arg0 : !qalias to f32 + return %0 : f32 +} + +// ----- + +// CHECK-LABEL: @dcast_per_layer_scalar_unsigned +// CHECK-SAME: %[[ARG_0:.*]]: !quant.uniform + +// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : !quant.uniform to i8 + +// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8 + +// CHECK: %[[STORED_FLOAT:.*]] = arith.uitofp %[[STORED_INT]] : i8 to f32 +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.uitofp %[[ZERO_POINT]] : i8 to f32 + +// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32 +// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32 +// CHECK: return %[[EXPRESSED]] : f32 + +!qalias = !quant.uniform +func.func @dcast_per_layer_scalar_unsigned(%arg0: !qalias) -> f32 { + %0 = quant.dcast %arg0 : !qalias to f32 + return %0 : f32 +} + From 5b40f256c6dd31f511da5ff58608e9c28732a0b6 Mon Sep 17 00:00:00 2001 From: Rafael Ubal Tena Date: Mon, 22 Jul 2024 14:33:38 -0400 Subject: [PATCH 12/22] Refactored quant.qcast and quant.dcast common code --- .../Quant/Transforms/LowerQuantOps.cpp | 418 +++++++----------- 1 file changed, 164 insertions(+), 254 deletions(-) diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp index 1b2d36dc672cf..77be3dd11d3ad 100644 --- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp +++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp @@ -223,12 +223,11 @@ Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input, return builder.create(loc, resultType, input); } -// Quantize a floating-point input using the given input shape, scale, and -// zero point. The stored value is clamped using the storage bounds encoded in -// the given quantized type. -Value quantizeScalarOrTensor(OpBuilder &builder, Location loc, Value input, - ArrayRef inputShape, Value scale, - Value zeroPoint, QuantizedType quantizedType) { +// Quantize a scalar or ranked tensor value. The stored value is clamped using +// the storage bounds encoded in the given quantized type. 
+Value quantizeValue(OpBuilder &builder, Location loc, Value input, + ArrayRef inputShape, Value scale, + Value zeroPoint, QuantizedType quantizedType) { // Convert scale to tensor if necessary auto inputType = input.getType(); scale = getScalarOrTensorConstant( @@ -267,9 +266,10 @@ Value quantizeScalarOrTensor(OpBuilder &builder, Location loc, Value input, return storedValueClamped; } -Value dequantizeScalarOrTensor(OpBuilder &builder, Location loc, Value input, - ArrayRef inputShape, Value scale, - Value zeroPoint, QuantizedType quantizedType) { +// Dequantize a scalar or ranked tensor value. +Value dequantizeValue(OpBuilder &builder, Location loc, Value input, + ArrayRef inputShape, Value scale, + Value zeroPoint, QuantizedType quantizedType) { // Convert scale to tensor if necessary auto inputType = input.getType(); scale = getScalarOrTensorConstant( @@ -299,125 +299,168 @@ Value dequantizeScalarOrTensor(OpBuilder &builder, Location loc, Value input, return result; } +// Convert a scalar or ranked tensor input with the given scale and zero point +// values. +// +// - input +// Scalar or ranked tensor value. +// +// - inputShape +// If 'input' is a tensor, combination or attributes/values representing its +// static/dynamic dimensions. If 'input' is a scalar, empty list. +// +// - scale +// Scale as a scalar value. +// +// - zeroPoint +// Zero point as a scalar value. +// +// - quantizedType +// Scalar quantized type of the result ('quant.qcast') or of the input +// ('quant.dcast'). +// +Value convertRanked(OpBuilder &builder, Location loc, Operation *op, + Value input, ArrayRef inputShape, Value scale, + Value zeroPoint, QuantizedType quantizedType) { + if (isa(op)) + return quantizeValue(builder, loc, input, inputShape, scale, zeroPoint, + quantizedType); + if (isa(op)) + return dequantizeValue(builder, loc, input, inputShape, scale, zeroPoint, + quantizedType); + llvm_unreachable("unexpected quant op"); +} -//===----------------------------------------------------------------------===// -// DequantizeCastOp -//===----------------------------------------------------------------------===// +Value convertPerLayerRanked(OpBuilder &builder, Location loc, Operation *op, + Value input, UniformQuantizedType quantizedType) { + + // Create scale and zero point constants + auto expressedType = quantizedType.getExpressedType(); + auto storageType = quantizedType.getStorageType(); + auto scaleAttr = + builder.getFloatAttr(expressedType, quantizedType.getScale()); + auto scale = builder.create(loc, expressedType, scaleAttr); + auto zeroPointAttr = + builder.getIntegerAttr(storageType, quantizedType.getZeroPoint()); + auto zeroPoint = + builder.create(loc, storageType, zeroPointAttr); + + auto inputShape = getScalarOrTensorShape(builder, loc, input); + return convertRanked(builder, loc, op, input, inputShape, scale, zeroPoint, + quantizedType); +} -class DequantizeCastOpConversion : public OpConversionPattern { - - Value convertPerLayerRanked(OpBuilder &builder, Location loc, Value input, - UniformQuantizedType quantizedType) const { - - // Create scale and zero point constants - auto expressedType = quantizedType.getExpressedType(); - auto storageType = quantizedType.getStorageType(); - auto scaleAttr = - builder.getFloatAttr(expressedType, quantizedType.getScale()); - auto scale = - builder.create(loc, expressedType, scaleAttr); - auto zeroPointAttr = - builder.getIntegerAttr(storageType, quantizedType.getZeroPoint()); - auto zeroPoint = - builder.create(loc, storageType, zeroPointAttr); - - auto 
inputShape = getScalarOrTensorShape(builder, loc, input); - return dequantizeScalarOrTensor(builder, loc, input, inputShape, scale, +Value convertPerLayer(OpBuilder &builder, Location loc, Operation *op, + Value input, UniformQuantizedType quantizedType) { + // Flatten input if unranked + bool isUnranked = isa(input.getType()); + Value inputShape; + if (isUnranked) + std::tie(input, inputShape) = flattenUnrankedTensor(builder, loc, input); + + // Process ranked tensor + auto result = convertPerLayerRanked(builder, loc, op, input, quantizedType); + + // Restore original shape if unranked + if (isUnranked) + result = restoreUnrankedTensorShape(builder, loc, result, inputShape); + + return result; +} + +Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op, + Value input, + UniformQuantizedPerAxisType quantizedType, + int64_t channelAxis) { + auto *context = builder.getContext(); + + auto inputType = cast(input.getType()); + auto inputRank = inputType.getRank(); + + auto scales = materializePerChannelScales(builder, loc, quantizedType); + auto zeroPoints = + materializePerChannelZeroPoints(builder, loc, quantizedType); + + auto storageType = quantizedType.getStorageType(); + auto initShape = tensor::getMixedSizes(builder, loc, input); + Value init = builder.create(loc, initShape, storageType); + + SmallVector iteratorTypes( + inputRank, utils::IteratorType::parallel); + auto channelAxisAffineMap = AffineMap::get( + inputRank, 0, builder.getAffineDimExpr(channelAxis), context); + SmallVector indexingMaps{ + builder.getMultiDimIdentityMap(inputRank), + channelAxisAffineMap, + channelAxisAffineMap, + builder.getMultiDimIdentityMap(inputRank) + }; + auto storedValue = builder.create( + loc, + init.getType(), // resultType + ValueRange{input, scales, zeroPoints}, // inputs + ValueRange{init}, // outputs + indexingMaps, + iteratorTypes, + [&](OpBuilder& builder, Location loc, ValueRange args) { + assert(args.size() == 4); + auto expressedValue = args[0]; + auto scale = args[1]; + auto zeroPoint = args[2]; + + auto result = convertRanked(builder, loc, op, expressedValue, {}, scale, zeroPoint, quantizedType); + + builder.create(loc, result); + }) + .getResult(0); + + return storedValue; +} + +Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op, + Value input, + UniformQuantizedPerAxisType quantizedType) { + // Flatten unranked tensor into a 3D ranked tensor if necessary + bool isUnranked = isa(input.getType()); + int64_t channelAxis = quantizedType.getQuantizedDimension(); + int64_t channelAxisSize = (int64_t) quantizedType.getScales().size(); + Value inputShape; + if (isUnranked) { + std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis( + builder, loc, input, channelAxis, channelAxisSize); + channelAxis = 1; } - Value convertPerLayer(OpBuilder &builder, Location loc, Value input, - UniformQuantizedType quantizedType) const { - // Flatten input if unranked - bool isUnranked = isa(input.getType()); - Value inputShape; - if (isUnranked) - std::tie(input, inputShape) = flattenUnrankedTensor(builder, loc, input); + // Work on a ranked tensor + auto result = convertPerChannelRanked(builder, loc, op, input, quantizedType, + channelAxis); - // Process ranked tensor - auto result = convertPerLayerRanked(builder, loc, input, quantizedType); + // Restore original tensor shape if unranked + if (isUnranked) + result = restoreUnrankedTensorShape(builder, loc, result, inputShape); - // Restore original shape if unranked - if (isUnranked) - result = 
restoreUnrankedTensorShape(builder, loc, result, inputShape); + return result; +} - return result; - } +Value convertQuantized(OpBuilder &builder, Location loc, Operation *op, + Value input, Type quantizedType) { + if (auto uniformQuantizedType = dyn_cast(quantizedType)) + return convertPerLayer(builder, loc, op, input, uniformQuantizedType); - Value convertPerChannelRanked(OpBuilder &builder, Location loc, Value input, - UniformQuantizedPerAxisType quantizedType, - int64_t channelAxis) const { - auto *context = builder.getContext(); - - auto inputType = cast(input.getType()); - auto inputRank = inputType.getRank(); - - auto scales = materializePerChannelScales(builder, loc, quantizedType); - auto zeroPoints = - materializePerChannelZeroPoints(builder, loc, quantizedType); - - auto storageType = quantizedType.getStorageType(); - auto initShape = tensor::getMixedSizes(builder, loc, input); - Value init = builder.create(loc, initShape, storageType); - - SmallVector iteratorTypes( - inputRank, utils::IteratorType::parallel); - auto channelAxisAffineMap = AffineMap::get( - inputRank, 0, builder.getAffineDimExpr(channelAxis), context); - SmallVector indexingMaps{ - builder.getMultiDimIdentityMap(inputRank), - channelAxisAffineMap, - channelAxisAffineMap, - builder.getMultiDimIdentityMap(inputRank) - }; - auto storedValue = builder.create( - loc, - init.getType(), // resultType - ValueRange{input, scales, zeroPoints}, // inputs - ValueRange{init}, // outputs - indexingMaps, - iteratorTypes, - [&](OpBuilder& builder, Location loc, ValueRange args) { - assert(args.size() == 4); - auto expressedValue = args[0]; - auto scale = args[1]; - auto zeroPoint = args[2]; - - auto result = quantizeScalarOrTensor(builder, loc, expressedValue, {}, - scale, zeroPoint, quantizedType); - - builder.create(loc, result); - }) - .getResult(0); - - return storedValue; - } + if (auto uniformQuantizedPerAxisType = + dyn_cast(quantizedType)) + return convertPerChannel(builder, loc, op, input, + uniformQuantizedPerAxisType); - Value convertPerChannel(OpBuilder &builder, Location loc, Value input, - UniformQuantizedPerAxisType quantizedType) const { - // Flatten unranked tensor into a 3D ranked tensor if necessary - bool isUnranked = isa(input.getType()); - int64_t channelAxis = quantizedType.getQuantizedDimension(); - int64_t channelAxisSize = (int64_t) quantizedType.getScales().size(); - Value inputShape; - if (isUnranked) { - std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis( - builder, loc, input, channelAxis, channelAxisSize); - channelAxis = 1; - } - - // Work on a ranked tensor - auto result = convertPerChannelRanked(builder, loc, input, quantizedType, - channelAxis); - - // Restore original tensor shape if unranked - if (isUnranked) - result = restoreUnrankedTensorShape(builder, loc, result, inputShape); - - return result; - } + llvm_unreachable("unexpected quantized type"); +} + +//===----------------------------------------------------------------------===// +// DequantizeCastOp +//===----------------------------------------------------------------------===// -public: +struct DequantizeCastOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -434,19 +477,7 @@ class DequantizeCastOpConversion : public OpConversionPattern( loc, storageScalarOrTensorType, input); - // Flatten unranked tensor input - Value result; - if (auto uniformQuantizedType = - dyn_cast(quantizedType)) { - result = convertPerLayer(rewriter, loc, input, uniformQuantizedType); - } 
else if (auto uniformQuantizedPerAxisType = - dyn_cast(quantizedType)) { - result = - convertPerChannel(rewriter, loc, input, uniformQuantizedPerAxisType); - } else { - llvm_unreachable("unexpected quantized type"); - } - + auto result = convertQuantized(rewriter, loc, op, input, quantizedType); rewriter.replaceOp(op, result); return success(); } @@ -457,120 +488,7 @@ class DequantizeCastOpConversion : public OpConversionPattern { - - Value convertPerLayerRanked(OpBuilder &builder, Location loc, Value input, - UniformQuantizedType quantizedType) const { - - // Create scale and zero point constants - auto expressedType = quantizedType.getExpressedType(); - auto storageType = quantizedType.getStorageType(); - auto scaleAttr = - builder.getFloatAttr(expressedType, quantizedType.getScale()); - auto scale = - builder.create(loc, expressedType, scaleAttr); - auto zeroPointAttr = - builder.getIntegerAttr(storageType, quantizedType.getZeroPoint()); - auto zeroPoint = - builder.create(loc, storageType, zeroPointAttr); - - auto inputShape = getScalarOrTensorShape(builder, loc, input); - return quantizeScalarOrTensor(builder, loc, input, inputShape, scale, - zeroPoint, quantizedType); - } - - Value convertPerLayer(OpBuilder &builder, Location loc, Value input, - UniformQuantizedType quantizedType) const { - // Flatten input if unranked - bool isUnranked = isa(input.getType()); - Value inputShape; - if (isUnranked) - std::tie(input, inputShape) = flattenUnrankedTensor(builder, loc, input); - - // Process ranked tensor - auto result = convertPerLayerRanked(builder, loc, input, quantizedType); - - // Restore original shape if unranked - if (isUnranked) - result = restoreUnrankedTensorShape(builder, loc, result, inputShape); - - return result; - } - - Value convertPerChannelRanked(OpBuilder &builder, Location loc, Value input, - UniformQuantizedPerAxisType quantizedType, - int64_t channelAxis) const { - auto *context = builder.getContext(); - - auto inputType = cast(input.getType()); - auto inputRank = inputType.getRank(); - - auto scales = materializePerChannelScales(builder, loc, quantizedType); - auto zeroPoints = - materializePerChannelZeroPoints(builder, loc, quantizedType); - - auto storageType = quantizedType.getStorageType(); - auto initShape = tensor::getMixedSizes(builder, loc, input); - Value init = builder.create(loc, initShape, storageType); - - SmallVector iteratorTypes( - inputRank, utils::IteratorType::parallel); - auto channelAxisAffineMap = AffineMap::get( - inputRank, 0, builder.getAffineDimExpr(channelAxis), context); - SmallVector indexingMaps{ - builder.getMultiDimIdentityMap(inputRank), - channelAxisAffineMap, - channelAxisAffineMap, - builder.getMultiDimIdentityMap(inputRank) - }; - auto storedValue = builder.create( - loc, - init.getType(), // resultType - ValueRange{input, scales, zeroPoints}, // inputs - ValueRange{init}, // outputs - indexingMaps, - iteratorTypes, - [&](OpBuilder& builder, Location loc, ValueRange args) { - assert(args.size() == 4); - auto expressedValue = args[0]; - auto scale = args[1]; - auto zeroPoint = args[2]; - - auto result = quantizeScalarOrTensor(builder, loc, expressedValue, {}, - scale, zeroPoint, quantizedType); - - builder.create(loc, result); - }) - .getResult(0); - - return storedValue; - } - - Value convertPerChannel(OpBuilder &builder, Location loc, Value input, - UniformQuantizedPerAxisType quantizedType) const { - // Flatten unranked tensor into a 3D ranked tensor if necessary - bool isUnranked = isa(input.getType()); - int64_t channelAxis 
= quantizedType.getQuantizedDimension(); - int64_t channelAxisSize = (int64_t) quantizedType.getScales().size(); - Value inputShape; - if (isUnranked) { - std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis( - builder, loc, input, channelAxis, channelAxisSize); - channelAxis = 1; - } - - // Work on a ranked tensor - auto result = convertPerChannelRanked(builder, loc, input, quantizedType, - channelAxis); - - // Restore original tensor shape if unranked - if (isUnranked) - result = restoreUnrankedTensorShape(builder, loc, result, inputShape); - - return result; - } - -public: +struct QuantizeCastOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -578,18 +496,10 @@ class QuantizeCastOpConversion : public OpConversionPattern(resultScalarType)) { - result = convertPerLayer(rewriter, loc, input, quantizedType); - } else if (auto quantizedType = - dyn_cast(resultScalarType)) { - result = convertPerChannel(rewriter, loc, input, quantizedType); - } else { - llvm_unreachable("unexpected quantized type"); - } + auto result = convertQuantized(rewriter, loc, op, input, quantizedType); // Cast stored value to result quantized value rewriter.replaceOpWithNewOp( From 6eabe11b9880d2e22b2d4077d4f2639845219cce Mon Sep 17 00:00:00 2001 From: Rafael Ubal Tena Date: Mon, 22 Jul 2024 17:29:17 -0400 Subject: [PATCH 13/22] Unit test for 'quant.dcast' lowering with bug fixes --- .../Quant/Transforms/LowerQuantOps.cpp | 172 ++++++++++-- mlir/test/Dialect/Quant/lower-quant-ops.mlir | 255 ++++++++++++++---- 2 files changed, 360 insertions(+), 67 deletions(-) diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp index 77be3dd11d3ad..4adeb9218ff8e 100644 --- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp +++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp @@ -76,6 +76,19 @@ Value getScalarOrTensorConstant(OpBuilder &builder, Location loc, Value scalar, return tensorConstant; } +// Reshape an unranked tensor into a 1D ranked tensor. +// +// - input +// Unranked tensor. +// +// Return values: +// +// - flatInput +// 1D ranked, dynamically shaped tensor. +// +// - inputShape +// 1D extent tensor containing the shape of the original unranked input. +// std::pair flattenUnrankedTensor(OpBuilder &builder, Location loc, Value input) { // Get unranked input shape and total size @@ -100,6 +113,28 @@ std::pair flattenUnrankedTensor(OpBuilder &builder, Location loc, return std::make_pair(flatInput, inputShape); } +// Reshape an unranked tensor into a 3D ranked tensor where the central +// dimension of the result tensor corresponds to dimension 'axis' of the input +// tensor. +// +// - input +// Unranked tensor. +// +// - axis +// Index of the input dimension around which other input dimiensions will be +// collapsed. +// +// - axisSize +// Size of input dimension 'axis'. +// +// Return values: +// +// - flatInput +// 3D ranked tensor of shape [?, axisSize, ?]. +// +// - inputShape +// 1D extent tensor containing the shape of the original unranked input. +// std::pair flattenUnrankedTensorAroundAxis(OpBuilder &builder, Location loc, Value input, @@ -142,6 +177,14 @@ std::pair flattenUnrankedTensorAroundAxis(OpBuilder &builder, return std::make_pair(flatInput, inputShape); } +// Reshape an input tensor into its original unranked shape. +// +// - input +// Ranked tensor. +// +// - inputShape +// 1D extent tensor. 
+// Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input, Value inputShape) { auto inputType = cast(input.getType()); @@ -150,6 +193,15 @@ Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input, return builder.create(loc, unrankedType, input, inputShape); } +// Create a tensor constant containing all scales in a per-channel quantized +// type. Example: +// +// !quant.uniform +// +// produces +// +// %cst = arith.constant dense<[2.0, 3.0]> : tensor<2xf32> +// Value materializePerChannelScales(OpBuilder &builder, Location loc, UniformQuantizedPerAxisType quantizedType) { auto scales = quantizedType.getScales(); @@ -162,6 +214,15 @@ Value materializePerChannelScales(OpBuilder &builder, Location loc, return builder.create(loc, tensorType, scalesAttr); } +// Create a tensor constant containing all zero points in a per-channel +// quantized type. Example: +// +// !quant.uniform +// +// produces +// +// %cst = arith.constant dense<[10, 20]> : tensor<2xi8> +// Value materializePerChannelZeroPoints( OpBuilder &builder, Location loc, UniformQuantizedPerAxisType quantizedType) { @@ -178,6 +239,19 @@ Value materializePerChannelZeroPoints( return builder.create(loc, tensorType, zeroPointsAttr); } +// Clamp the given scalar or tensor input using the storage bounds encoded in +// the given quantized type, if present. +// +// - input +// Scalar or ranked tensor input. The element type must match the storage type +// of 'quantizedType'. +// +// - inputShape +// If 'input' is a tensor, combination of attributes/values representing its +// static/dynamic dimensions. If 'input' is a scalar, empty list. +// +// - quantizedType +// Per-axis or per-channel quantized type. Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input, ArrayRef inputShape, QuantizedType quantizedType) { @@ -209,6 +283,7 @@ Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input, return input; } +// Emit op 'arith.fptosi' or 'arith.fptoui'. Value convertFloatToInteger(OpBuilder &builder, Location loc, Value input, Type resultType, bool isSigned) { if (isSigned) @@ -216,6 +291,7 @@ Value convertFloatToInteger(OpBuilder &builder, Location loc, Value input, return builder.create(loc, resultType, input); } +// Emit op 'arith.sitofp' or 'arith.uitofp'. Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input, Type resultType, bool isSigned) { if (isSigned) @@ -225,6 +301,8 @@ Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input, // Quantize a scalar or ranked tensor value. The stored value is clamped using // the storage bounds encoded in the given quantized type. +// +// See function 'convertRanked()' below for a description of the arguments. Value quantizeValue(OpBuilder &builder, Location loc, Value input, ArrayRef inputShape, Value scale, Value zeroPoint, QuantizedType quantizedType) { @@ -266,7 +344,9 @@ Value quantizeValue(OpBuilder &builder, Location loc, Value input, return storedValueClamped; } -// Dequantize a scalar or ranked tensor value. +// Dequantize a scalar or ranked tensor input. +// +// See function 'convertRanked()' below for a description of the arguments. Value dequantizeValue(OpBuilder &builder, Location loc, Value input, ArrayRef inputShape, Value scale, Value zeroPoint, QuantizedType quantizedType) { @@ -310,10 +390,10 @@ Value dequantizeValue(OpBuilder &builder, Location loc, Value input, // static/dynamic dimensions. If 'input' is a scalar, empty list. // // - scale -// Scale as a scalar value. 
+// Scale as a floating-point scalar value. // // - zeroPoint -// Zero point as a scalar value. +// Zero point as an integer scalar value. // // - quantizedType // Scalar quantized type of the result ('quant.qcast') or of the input @@ -331,9 +411,20 @@ Value convertRanked(OpBuilder &builder, Location loc, Operation *op, llvm_unreachable("unexpected quant op"); } +// Convert an operation using per-layer quantization with a scalar or ranked +// tensor input. +// +// - op +// 'quant.dcast' or 'quant.qcast' op. +// +// - input +// Scalar or ranked tensor. +// +// - quantizedType +// Per-layer quantized type. +// Value convertPerLayerRanked(OpBuilder &builder, Location loc, Operation *op, Value input, UniformQuantizedType quantizedType) { - // Create scale and zero point constants auto expressedType = quantizedType.getExpressedType(); auto storageType = quantizedType.getStorageType(); @@ -350,6 +441,17 @@ Value convertPerLayerRanked(OpBuilder &builder, Location loc, Operation *op, quantizedType); } +// Convert an operation using per-layer quantization. +// +// - op +// 'quant.dcast' or 'quant.qcast' op. +// +// - input +// Scalar, ranked tensor, or unranked tensor. +// +// - quantizedType +// Per-layer quantized type. +// Value convertPerLayer(OpBuilder &builder, Location loc, Operation *op, Value input, UniformQuantizedType quantizedType) { // Flatten input if unranked @@ -368,6 +470,18 @@ Value convertPerLayer(OpBuilder &builder, Location loc, Operation *op, return result; } +// Convert an operation using per-channel quantization and a scalar or ranked +// tensor as an input. +// +// - op +// 'quant.dcast' or 'quant.qcast' op. +// +// - input +// Scalar or ranked tensor. +// +// - quantizedType +// Per-channel quantized type. +// Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op, Value input, UniformQuantizedPerAxisType quantizedType, @@ -381,9 +495,11 @@ Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op, auto zeroPoints = materializePerChannelZeroPoints(builder, loc, quantizedType); - auto storageType = quantizedType.getStorageType(); + auto elementType = isa(inputType.getElementType()) + ? quantizedType.getStorageType() + : quantizedType.getExpressedType(); auto initShape = tensor::getMixedSizes(builder, loc, input); - Value init = builder.create(loc, initShape, storageType); + Value init = builder.create(loc, initShape, elementType); SmallVector iteratorTypes( inputRank, utils::IteratorType::parallel); @@ -395,7 +511,7 @@ Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op, channelAxisAffineMap, builder.getMultiDimIdentityMap(inputRank) }; - auto storedValue = builder.create( + auto result = builder.create( loc, init.getType(), // resultType ValueRange{input, scales, zeroPoints}, // inputs @@ -404,20 +520,31 @@ Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op, iteratorTypes, [&](OpBuilder& builder, Location loc, ValueRange args) { assert(args.size() == 4); - auto expressedValue = args[0]; + auto input = args[0]; auto scale = args[1]; auto zeroPoint = args[2]; - auto result = convertRanked(builder, loc, op, expressedValue, {}, scale, + auto result = convertRanked(builder, loc, op, input, {}, scale, zeroPoint, quantizedType); builder.create(loc, result); }) .getResult(0); - return storedValue; + return result; } +// Convert an operation using per-channel quantization. +// +// - op +// 'quant.dcast' or 'quant.qcast' op. +// +// - input +// Scalar, ranked tensor, or unranked tensor. 
+// +// - quantizedType +// Per-channel quantized type. +// Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op, Value input, UniformQuantizedPerAxisType quantizedType) { @@ -443,6 +570,19 @@ Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op, return result; } +// Convert a quantization operation. +// +// - op +// 'quant.dcast' or 'quant.qcast' op. +// +// - input +// Scalar, ranked tensor, or unranked tensor. The element type matches +// the storage type (quant.dcast) or expressed type (quant.qcast) of +// 'quantizedType'. +// +// - quantizedType +// Per-layer or per-channel quantized type. +// Value convertQuantized(OpBuilder &builder, Location loc, Operation *op, Value input, Type quantizedType) { if (auto uniformQuantizedType = dyn_cast(quantizedType)) @@ -456,10 +596,7 @@ Value convertQuantized(OpBuilder &builder, Location loc, Operation *op, llvm_unreachable("unexpected quantized type"); } -//===----------------------------------------------------------------------===// -// DequantizeCastOp -//===----------------------------------------------------------------------===// - +// Lowering pattern for 'quant.dcast' struct DequantizeCastOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -478,16 +615,13 @@ struct DequantizeCastOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; diff --git a/mlir/test/Dialect/Quant/lower-quant-ops.mlir b/mlir/test/Dialect/Quant/lower-quant-ops.mlir index 1030d5b20620b..6bba9f5c03772 100644 --- a/mlir/test/Dialect/Quant/lower-quant-ops.mlir +++ b/mlir/test/Dialect/Quant/lower-quant-ops.mlir @@ -1,5 +1,209 @@ // RUN: mlir-opt %s --lower-quant-ops --split-input-file | FileCheck %s +// CHECK-LABEL: @dcast_per_layer_scalar +// CHECK-SAME: %[[ARG_0:.*]]: !quant.uniform + +// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : !quant.uniform to i8 + +// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8 +// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : i8 to f32 +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32 + +// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32 +// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32 +// CHECK: return %[[EXPRESSED]] : f32 + +!qalias = !quant.uniform +func.func @dcast_per_layer_scalar(%arg0: !qalias) -> f32 { + %0 = quant.dcast %arg0 : !qalias to f32 + return %0 : f32 +} + +// ----- + +// CHECK-LABEL: @dcast_per_layer_scalar_unsigned +// CHECK-SAME: %[[ARG_0:.*]]: !quant.uniform + +// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : !quant.uniform to i8 + +// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8 + +// CHECK: %[[STORED_FLOAT:.*]] = arith.uitofp %[[STORED_INT]] : i8 to f32 +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.uitofp %[[ZERO_POINT]] : i8 to f32 + +// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32 +// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32 +// CHECK: return %[[EXPRESSED]] : f32 + +!qalias = !quant.uniform +func.func @dcast_per_layer_scalar_unsigned(%arg0: !qalias) -> f32 { + %0 = quant.dcast %arg0 : !qalias to f32 + return %0 : f32 +} + +// ----- + +// CHECK-LABEL: @dcast_per_layer_0d +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : tensor> to tensor + +// CHECK: 
%[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]] : tensor<f32>
+// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : tensor<i8> to tensor<f32>
+// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]] : tensor<i8>
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor<i8> to tensor<f32>
+
+// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : tensor<f32>
+// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE_TENSOR]] : tensor<f32>
+// CHECK: return %[[EXPRESSED]] : tensor<f32>
+
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @dcast_per_layer_0d(%arg0: tensor<!qalias>) -> tensor<f32> {
+  %0 = quant.dcast %arg0 : tensor<!qalias> to tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: @dcast_per_layer_ranked
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<3x?x5x!quant.uniform<i8:f32, 2.000000e+00:10>>
+
+// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : tensor<3x?x5x!quant.uniform<i8:f32, 2.000000e+00:10>> to tensor<3x?x5xi8>
+// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+// CHECK: %[[C_1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM_1:.*]] = tensor.dim %[[STORED_INT]], %[[C_1]] : tensor<3x?x5xi8>
+// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]]{{\[}}%[[DIM_1]]] : tensor<3x?x5xf32>
+// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : tensor<3x?x5xi8> to tensor<3x?x5xf32>
+// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]]{{\[}}%[[DIM_1]]] : tensor<3x?x5xi8>
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor<3x?x5xi8> to tensor<3x?x5xf32>
+
+// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : tensor<3x?x5xf32>
+// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE_TENSOR]] : tensor<3x?x5xf32>
+// CHECK: return %[[EXPRESSED]] : tensor<3x?x5xf32>
+
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @dcast_per_layer_ranked(%arg0: tensor<3x?x5x!qalias>) -> tensor<3x?x5xf32> {
+  %0 = quant.dcast %arg0 : tensor<3x?x5x!qalias> to tensor<3x?x5xf32>
+  return %0 : tensor<3x?x5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @dcast_per_layer_unranked
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<*x!quant.uniform<i8:f32, 2.000000e+00:10>>
+
+// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : tensor<*x!quant.uniform<i8:f32, 2.000000e+00:10>> to tensor<*xi8>
+// CHECK: %[[INPUT_SHAPE:.*]] = shape.shape_of %[[STORED_INT]] : tensor<*xi8> -> tensor<?xindex>
+// CHECK: %[[INPUT_SIZE:.*]] = shape.num_elements %[[INPUT_SHAPE]] : tensor<?xindex> -> index
+// CHECK: %[[COLLAPSED_SHAPE:.*]] = tensor.from_elements %[[INPUT_SIZE]] : tensor<1xindex>
+// CHECK: %[[STORED_COLLAPSED:.*]] = tensor.reshape %[[STORED_INT]](%[[COLLAPSED_SHAPE]]) : (tensor<*xi8>, tensor<1xindex>) -> tensor<?xi8>
+// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+// CHECK: %[[C_0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[STORED_COLLAPSED]], %[[C_0]] : tensor<?xi8>
+// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]]{{\[}}%[[DIM_0]]] : tensor<?xf32>
+// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_COLLAPSED]] : tensor<?xi8> to tensor<?xf32>
+// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]]{{\[}}%[[DIM_0]]] : tensor<?xi8>
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor<?xi8> to tensor<?xf32>
+
+// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : tensor<?xf32>
+// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE_TENSOR]] : tensor<?xf32>
+
+// CHECK: %[[EXPRESSED_EXPANDED:.*]] = 
tensor.reshape %[[EXPRESSED]](%[[INPUT_SHAPE]]) : (tensor, tensor) -> tensor<*xf32> +// CHECK: return %[[EXPRESSED_EXPANDED]] : tensor<*xf32> + +!qalias = !quant.uniform +func.func @dcast_per_layer_unranked(%arg0: tensor<*x!qalias>) -> tensor<*xf32> { + %0 = quant.dcast %arg0 : tensor<*x!qalias> to tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1)> + +// CHECK-LABEL: @dcast_per_channel_ranked +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK: %[[STORED_TENSOR:.*]] = quant.scast %[[ARG_0]] : tensor<4x?x?x5x!quant.uniform> to tensor<4x?x?x5xi8> + +// CHECK: %[[SCALES:.*]] = arith.constant dense<[2.000000e+00, 3.000000e+00]> : tensor<2xf32> +// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<[10, 20]> : tensor<2xi8> +// CHECK: %[[C_1:.*]] = arith.constant 1 : index +// CHECK: %[[DIM_1:.*]] = tensor.dim %[[STORED_TENSOR]], %[[C_1]] : tensor<4x?x?x5xi8> +// CHECK: %[[C_2:.*]] = arith.constant 2 : index +// CHECK: %[[DIM_2:.*]] = tensor.dim %[[STORED_TENSOR]], %[[C_2]] : tensor<4x?x?x5xi8> +// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM_1]], %[[DIM_2]]) : tensor<4x?x?x5xf32> +// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[STORED_TENSOR]], %[[SCALES]], %[[ZERO_POINTS]] : tensor<4x?x?x5xi8>, tensor<2xf32>, tensor<2xi8>) outs(%[[INIT]] : tensor<4x?x?x5xf32>) { +// CHECK: ^bb0(%[[STORED_INT:.*]]: i8, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: f32): +// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : i8 to f32 +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32 +// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32 +// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32 +// CHECK: linalg.yield %[[EXPRESSED]] : f32 +// CHECK: } -> tensor<4x?x?x5xf32> +// CHECK: return %[[GENERIC]] : tensor<4x?x?x5xf32> + +!qalias = !quant.uniform +func.func @dcast_per_channel_ranked(%arg0: tensor<4x?x?x5x!qalias>) -> tensor<4x?x?x5xf32> { + %0 = quant.dcast %arg0 : tensor<4x?x?x5x!qalias> to tensor<4x?x?x5xf32> + return %0 : tensor<4x?x?x5xf32> +} + +// ----- + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1)> + +// CHECK-LABEL: @dcast_per_channel_unranked +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK: %[[STORED_TENSOR:.*]] = quant.scast %[[ARG_0]] : tensor<*x!quant.uniform> to tensor<*xi8> +// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[STORED_TENSOR]] : tensor<*xi8> -> tensor +// CHECK: %[[CHANNEL_AXIS:.*]] = arith.constant 2 : index +// CHECK: %[[CHANNEL_AXIS_NEXT:.*]] = arith.constant 3 : index +// CHECK: %[[SHAPE_LEFT:.*]], %[[DISCARDED_0:.*]] = "shape.split_at"(%[[SHAPE]], %[[CHANNEL_AXIS]]) : (tensor, index) -> (tensor, tensor) +// CHECK: %[[SIZE_LEFT:.*]] = shape.num_elements %[[SHAPE_LEFT]] : tensor -> index +// CHECK: %[[DISCARDED_1:.*]], %[[SHAPE_RIGHT:.*]] = "shape.split_at"(%[[SHAPE]], %[[CHANNEL_AXIS_NEXT]]) : (tensor, index) -> (tensor, tensor) +// CHECK: %[[SIZE_RIGHT:.*]] = shape.num_elements %[[SHAPE_RIGHT]] : tensor -> index + +// CHECK: %[[NUM_CHANNELS:.*]] = arith.constant 3 : index +// CHECK: %[[COLLAPSED_SHAPE:.*]] = tensor.from_elements %[[SIZE_LEFT]], %[[NUM_CHANNELS]], %[[SIZE_RIGHT]] : tensor<3xindex> +// CHECK: 
%[[STORED_COLLAPSED:.*]] = tensor.reshape %[[STORED_TENSOR]](%[[COLLAPSED_SHAPE]]) : (tensor<*xi8>, tensor<3xindex>) -> tensor + +// CHECK: %[[SCALES:.*]] = arith.constant dense<[2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<3xf32> +// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<[10, 20, 30]> : tensor<3xi8> +// CHECK: %[[C_0:.*]] = arith.constant 0 : index +// CHECK: %[[DIM_0:.*]] = tensor.dim %[[STORED_COLLAPSED]], %[[C_0]] : tensor +// CHECK: %[[C_2:.*]] = arith.constant 2 : index +// CHECK: %[[DIM_2:.*]] = tensor.dim %[[STORED_COLLAPSED]], %[[C_2]] : tensor +// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM_0]], %[[DIM_2]]) : tensor +// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[STORED_COLLAPSED]], %[[SCALES]], %[[ZERO_POINTS]] : tensor, tensor<3xf32>, tensor<3xi8>) outs(%[[INIT]] : tensor) { +// CHECK: ^bb0(%[[STORED_INT:.*]]: i8, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: f32): +// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : i8 to f32 +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32 +// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32 +// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32 +// CHECK: linalg.yield %[[EXPRESSED]] : f32 +// CHECK: } -> tensor + +// CHECK: %[[EXPRESSED_EXPANDED:.*]] = tensor.reshape %[[GENERIC]](%[[SHAPE]]) : (tensor, tensor) -> tensor<*xf32> +// CHECK: return %[[EXPRESSED_EXPANDED]] : tensor<*xf32> + +!qalias = !quant.uniform +func.func @dcast_per_channel_unranked(%arg0: tensor<*x!qalias>) -> tensor<*xf32> { + %0 = quant.dcast %arg0 : tensor<*x!qalias> to tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + // CHECK-LABEL: @qcast_per_layer_scalar // CHECK-SAME: %[[ARG_0:.*]]: f32 @@ -219,7 +423,7 @@ func.func @qcast_per_layer_unranked(%arg0: tensor<*xf32>) -> tensor<*x!qalias> { !qalias = !quant.uniform func.func @qcast_per_channel_ranked(%arg0: tensor<4x?x?x5xf32>) -> tensor<4x?x?x5x!qalias> { - %0 = "quant.qcast"(%arg0) : (tensor<4x?x?x5xf32>) -> tensor<4x?x?x5x!qalias> + %0 = quant.qcast %arg0 : tensor<4x?x?x5xf32> to tensor<4x?x?x5x!qalias> return %0 : tensor<4x?x?x5x!qalias> } @@ -253,7 +457,7 @@ func.func @qcast_per_channel_ranked(%arg0: tensor<4x?x?x5xf32>) -> tensor<4x?x?x !qalias = !quant.uniform:f32:1, {2.0, 3.0}> func.func @qcast_per_channel_ranked_bounds(%arg0: tensor<4x2x5xf32>) -> tensor<4x2x5x!qalias> { - %0 = "quant.qcast"(%arg0) : (tensor<4x2x5xf32>) -> tensor<4x2x5x!qalias> + %0 = quant.qcast %arg0 : tensor<4x2x5xf32> to tensor<4x2x5x!qalias> return %0 : tensor<4x2x5x!qalias> } @@ -301,52 +505,7 @@ func.func @qcast_per_channel_ranked_bounds(%arg0: tensor<4x2x5xf32>) -> tensor<4 !qalias = !quant.uniform func.func @qcast_per_channel_unranked(%arg0: tensor<*xf32>) -> tensor<*x!qalias> { - %0 = "quant.qcast"(%arg0) : (tensor<*xf32>) -> tensor<*x!qalias> + %0 = quant.qcast %arg0 : tensor<*xf32> to tensor<*x!qalias> return %0 : tensor<*x!qalias> } -// ----- - -// CHECK-LABEL: @dcast_per_layer_scalar -// CHECK-SAME: %[[ARG_0:.*]]: !quant.uniform - -// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : !quant.uniform to i8 - -// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32 -// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8 -// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : i8 to f32 -// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32 
- -// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32 -// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32 -// CHECK: return %[[EXPRESSED]] : f32 - -!qalias = !quant.uniform -func.func @dcast_per_layer_scalar(%arg0: !qalias) -> f32 { - %0 = quant.dcast %arg0 : !qalias to f32 - return %0 : f32 -} - -// ----- - -// CHECK-LABEL: @dcast_per_layer_scalar_unsigned -// CHECK-SAME: %[[ARG_0:.*]]: !quant.uniform - -// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : !quant.uniform to i8 - -// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32 -// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8 - -// CHECK: %[[STORED_FLOAT:.*]] = arith.uitofp %[[STORED_INT]] : i8 to f32 -// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.uitofp %[[ZERO_POINT]] : i8 to f32 - -// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32 -// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32 -// CHECK: return %[[EXPRESSED]] : f32 - -!qalias = !quant.uniform -func.func @dcast_per_layer_scalar_unsigned(%arg0: !qalias) -> f32 { - %0 = quant.dcast %arg0 : !qalias to f32 - return %0 : f32 -} - From 647b8ec64d01af7606c29a0f90ad98b49ba8fa95 Mon Sep 17 00:00:00 2001 From: Rafael Ubal Tena Date: Tue, 23 Jul 2024 15:48:56 -0400 Subject: [PATCH 14/22] Verifiers for 'quant.dcast' and 'quant.qcast' --- mlir/include/mlir/Dialect/Quant/IR/Quant.h | 2 + .../mlir/Dialect/Quant/IR/QuantBase.td | 96 ++++++----- .../include/mlir/Dialect/Quant/IR/QuantOps.td | 40 ++++- mlir/lib/Dialect/Quant/IR/QuantOps.cpp | 105 +++++++++++- mlir/test/Dialect/Quant/invalid.mlir | 161 ++++++++++++++++++ mlir/test/Dialect/Quant/ops.mlir | 96 +++++++++++ 6 files changed, 443 insertions(+), 57 deletions(-) create mode 100644 mlir/test/Dialect/Quant/invalid.mlir create mode 100644 mlir/test/Dialect/Quant/ops.mlir diff --git a/mlir/include/mlir/Dialect/Quant/IR/Quant.h b/mlir/include/mlir/Dialect/Quant/IR/Quant.h index c5ca88ec69795..11a969a3ee519 100644 --- a/mlir/include/mlir/Dialect/Quant/IR/Quant.h +++ b/mlir/include/mlir/Dialect/Quant/IR/Quant.h @@ -24,7 +24,9 @@ namespace mlir { namespace quant { +class QuantizedType; class UniformQuantizedType; +class UniformQuantizedPerAxisType; } // namespace quant } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td index e465d855c1986..8ef89a0d1393b 100644 --- a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td +++ b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td @@ -22,53 +22,63 @@ def Quant_Dialect : Dialect { let useDefaultTypePrinterParser = 1; } + //===----------------------------------------------------------------------===// -// Quantization type definitions +// Type definitions //===----------------------------------------------------------------------===// -class quant_TypedPrimitiveOrContainer : - Type.predicate, - VectorOf<[etype]>.predicate]>, - "primitive/tensor/vector of " # etype.summary>; +class quant_ScalarOrTensorOf : + Type.predicate]>, + "scalar or tensor of " # etype.summary>; -// An implementation of QuantizedType. def quant_QuantizedType : - Type($_self)">, "QuantizedType">; - -// A primitive type that can represent a real value. This is either a -// floating point value or a quantized type. -def quant_RealPrimitiveType : - Type, - "real valued primitive (float or quantized type)">; - -// A primitive type that can represent a storage value. This is either an -// integer or quantized type. 
-def quant_StoragePrimitiveType : - Type, - "quantized storage primitive (integer or quantized type)">; - -// A primitive or container of RealPrimitiveType. -def quant_RealValueType : - quant_TypedPrimitiveOrContainer; - -// A primitive or container of StoragePrimitiveType. -def quant_StorageValueType : - quant_TypedPrimitiveOrContainer; - -// Either a real valued or storage primitive or container type. -def quant_RealOrStorageValueType : - Type, - "real valued or storage primitive or container type">; - -// An implementation of UniformQuantizedType. -def quant_UniformQuantizedType : - DialectType($_self)">, - "UniformQuantizedType">; - -// Predicate for detecting a container or primitive of UniformQuantizedType. -def quant_UniformQuantizedValueType : - quant_TypedPrimitiveOrContainer; + Type($_self)">, "quantized type">; + +def quant_ScalarType : + Type, "integer, float, or quantized scalar">; + +def quant_IntegerOrQuantizedType : + Type>; + +def quant_FloatScalarOrTensor : + quant_ScalarOrTensorOf; + +def quant_IntegerScalarOrTensor : + quant_ScalarOrTensorOf; + +def quant_QuantizedScalarOrTensor : + quant_ScalarOrTensorOf; + +def quant_IntegerOrQuantizedScalarOrTensor : + quant_ScalarOrTensorOf; + + +//===----------------------------------------------------------------------===// +// Traits +//===----------------------------------------------------------------------===// + +def quant_SameScalarOrTensorShape : + PredOpTrait< + "input and result are both scalars or both tensors with matching shape", + Or<[ + And<[ + TypeIsPred<"input", quant_ScalarType>, + TypeIsPred<"result", quant_ScalarType> + ]>, + And<[ + TypeIsPred<"input", AnyUnrankedTensor>, + TypeIsPred<"result", AnyUnrankedTensor> + ]>, + And<[ + TypeIsPred<"input", AnyRankedTensor>, + TypeIsPred<"result", AnyRankedTensor>, + AllShapesMatch<["input", "result"]>.predicate + ]> + ]> + >; #endif // QUANT_BASE diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td index 7a6d270dbb6e9..3a24c8148d1f5 100644 --- a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td +++ b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td @@ -28,7 +28,9 @@ class quant_Op traits> : // Quantization casts //===----------------------------------------------------------------------===// -def quant_DequantizeCastOp : quant_Op<"dcast", [Pure]> { +def quant_DequantizeCastOp : quant_Op<"dcast", [ + Pure, + quant_SameScalarOrTensorShape]> { let summary = "convert back from a quantized to quantizable (expressed) type operation"; let description = [{ A DequantizeCast op `dcast` represents the inverse of a `qcast`, @@ -42,12 +44,22 @@ def quant_DequantizeCastOp : quant_Op<"dcast", [Pure]> { all operands to ops that must operate with the expressed type (typically math ops prior to lowering to target-specific, quantized kernels). }]; - let arguments = (ins quant_RealValueType:$input); - let results = (outs quant_RealValueType:$result); + let arguments = (ins quant_QuantizedScalarOrTensor:$input); + let results = (outs quant_FloatScalarOrTensor:$result); let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)"; + let hasVerifier = 1; + let extraClassDeclaration = [{ + /// Return the float type of the scalar or tensor result. + FloatType getFloatType(); + + /// Return the quantized type of the scalar or tensor input. 
+ quant::QuantizedType getQuantizedType(); + }]; } -def quant_QuantizeCastOp : quant_Op<"qcast", [Pure]> { +def quant_QuantizeCastOp : quant_Op<"qcast", [ + Pure, + quant_SameScalarOrTensorShape]> { let summary = "convert a quantizable type to a quantized type"; let description = [{ A QuantizeCast `qcast` represents a potential type shift from a quantizable @@ -71,12 +83,22 @@ def quant_QuantizeCastOp : quant_Op<"qcast", [Pure]> { it is legal to use a quantized representation (but is not known to be acceptable). }]; - let arguments = (ins quant_RealValueType:$input); - let results = (outs quant_RealValueType:$result); + let arguments = (ins quant_FloatScalarOrTensor:$input); + let results = (outs quant_QuantizedScalarOrTensor:$result); let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)"; + let hasVerifier = 1; + let extraClassDeclaration = [{ + /// Return the float type of the scalar or tensor input. + FloatType getFloatType(); + + /// Return the quantized type of the scalar or tensor result. + quant::QuantizedType getQuantizedType(); + }]; } -def quant_StorageCastOp : quant_Op<"scast", [Pure]> { +def quant_StorageCastOp : quant_Op<"scast", [ + Pure, + quant_SameScalarOrTensorShape]> { let summary = "cast from or to a type based on the storage type and the corresponding quantized type"; let description = [{ A StorageCast `scast` represents a cast from or to a type based on the @@ -97,8 +119,8 @@ def quant_StorageCastOp : quant_Op<"scast", [Pure]> { vector<4xi8> -> vector<4x!quant<"uniform[i8:f32]{1.0}">> ``` }]; - let arguments = (ins quant_RealOrStorageValueType:$input); - let results = (outs quant_RealOrStorageValueType:$result); + let arguments = (ins quant_IntegerOrQuantizedScalarOrTensor:$input); + let results = (outs quant_IntegerOrQuantizedScalarOrTensor:$result); let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)"; let hasFolder = 1; } diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp index e04ca7eb7e715..ef66ccf3d9e01 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp @@ -15,6 +15,7 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/MathExtras.h" @@ -28,13 +29,70 @@ namespace quant { namespace { -Type getPrimitiveType(Type ty) { - if (auto tensorType = dyn_cast(ty)) - return tensorType.getElementType(); - return ty; +// Verify the integrity of per-axis quantization information, if present. +// +// - quantizedType +// Any quantized type. Any quantized type with no per-axis quantization is +// ignored. +// +// - containerType +// Original input or result type of the operation using the provided quantized +// type. Used to ensure that the quantized type appears within a tensor and +// that the tensor is compatible with per-axis quantization information. 
+//
+LogicalResult verifyPerAxisQuantization(Operation *op,
+                                        QuantizedType quantizedType,
+                                        Type containerType) {
+  auto quantizedPerAxisType =
+      dyn_cast<UniformQuantizedPerAxisType>(quantizedType);
+  if (!quantizedPerAxisType)
+    return success();
+
+  auto tensorType = dyn_cast<TensorType>(containerType);
+  if (!tensorType)
+    return op->emitError("scalar types may not use per-axis quantization");
+
+  if (!tensorType.hasRank())
+    return success();
+
+  int64_t quantizedDimension = quantizedPerAxisType.getQuantizedDimension();
+  if (quantizedDimension >= tensorType.getRank())
+    return op->emitError("quantized dimension must be less than tensor rank");
+
+  int64_t quantizedDimensionSize = tensorType.getDimSize(quantizedDimension);
+  if (quantizedDimensionSize != ShapedType::kDynamic &&
+      quantizedDimensionSize !=
+          (int64_t)quantizedPerAxisType.getScales().size())
+    return op->emitError(
+        "quantized dimension size does not match number of scales");
+
+  return success();
+}
+
+// Common verification logic for 'quant.dcast' and 'quant.qcast' ops.
+//
+// - quantizedType
+//   Quantized type used in the input ('quant.dcast') or result
+//   ('quant.qcast'), whether as a primitive type or in a tensor.
+//
+// - floatType
+//   Float type used in the input ('quant.qcast') or result ('quant.dcast'),
+//   whether as a primitive type or in a tensor.
+//
+// - containerType
+//   Type of original input or result.
+//
+LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType,
+                                   FloatType floatType, Type containerType) {
+  if (quantizedType.getExpressedType() != floatType)
+    return op->emitError(
+        "expressed type in quantized type expected to match float type");
+
+  if (failed(verifyPerAxisQuantization(op, quantizedType, containerType)))
+    return failure();
+
+  return success();
 }
 
-} // namespace
+} // namespace
 
 //===----------------------------------------------------------------------===//
@@ -52,6 +110,24 @@ void QuantDialect::initialize() {
 }
 
+//===----------------------------------------------------------------------===//
+// DequantizeCastOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult DequantizeCastOp::verify() {
+  return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(),
+                              getInput().getType());
+}
+
+FloatType DequantizeCastOp::getFloatType() {
+  return cast<FloatType>(getElementTypeOrSelf(getResult().getType()));
+}
+
+QuantizedType DequantizeCastOp::getQuantizedType() {
+  return cast<QuantizedType>(getElementTypeOrSelf(getInput().getType()));
+}
+
+
 //===----------------------------------------------------------------------===//
 // StorageCastOp
 //===----------------------------------------------------------------------===//
@@ -65,6 +141,25 @@ OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) {
   return srcScastOp.getInput();
 }
 
+
+//===----------------------------------------------------------------------===//
+// QuantizeCastOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult QuantizeCastOp::verify() {
+  return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(),
+                              getInput().getType());
+}
+
+FloatType QuantizeCastOp::getFloatType() {
+  return cast<FloatType>(getElementTypeOrSelf(getInput().getType()));
+}
+
+QuantizedType QuantizeCastOp::getQuantizedType() {
+  return cast<QuantizedType>(getElementTypeOrSelf(getResult().getType()));
+}
+
+
 } // namespace quant
 } // namespace mlir
diff --git a/mlir/test/Dialect/Quant/invalid.mlir b/mlir/test/Dialect/Quant/invalid.mlir
new file mode 100644
index 0000000000000..9976ce5d00d65
--- /dev/null
+++ 
b/mlir/test/Dialect/Quant/invalid.mlir @@ -0,0 +1,161 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics + +func.func @dcast_invalid_input(%arg0: f32) { + // expected-error@+1 {{operand #0 must be scalar or tensor of quantized type}} + %0 = quant.dcast %arg0 : f32 to f32 + return +} + +// ----- + +!qalias = !quant.uniform +func.func @dcast_invalid_result(%arg0: !qalias) { + // expected-error@+1 {{result #0 must be scalar or tensor of floating-point}} + %0 = quant.dcast %arg0 : !qalias to !qalias + return +} + +// ----- + +!qalias = !quant.uniform +func.func @dcast_mismatch_scalar_tensor(%arg0: !qalias) { + // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}} + %0 = quant.dcast %arg0 : !qalias to tensor + return +} + +// ----- + +!qalias = !quant.uniform +func.func @dcast_mismatch_ranked_unranked_tensor(%arg0: tensor) { + // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}} + %0 = quant.dcast %arg0 : tensor to tensor<*xf32> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @dcast_mismatch_static_dynamic_tensor(%arg0: tensor<2x3x!qalias>) { + // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}} + %0 = quant.dcast %arg0 : tensor<2x3x!qalias> to tensor + return +} + +// ----- + +!qalias = !quant.uniform +func.func @dcast_float_type_mismatch(%arg0: !qalias) { + // expected-error@+1 {{expressed type in quantized type expected to match float type}} + %0 = quant.dcast %arg0 : !qalias to f64 + return +} + +// ----- + +!qalias = !quant.uniform +func.func @dcast_per_axis_scalar(%arg0: !qalias) { + // expected-error@+1 {{scalar types may not use per-axis quantization}} + %0 = quant.dcast %arg0 : !qalias to f32 + return +} + +// ----- + +!qalias = !quant.uniform +func.func @dcast_per_axis_invalid_rank(%arg0: tensor<2x3x!qalias>) { + // expected-error@+1 {{quantized dimension must be less than tensor rank}} + %0 = quant.dcast %arg0 : tensor<2x3x!qalias> to tensor<2x3xf32> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @dcast_per_axis_invalid_rank(%arg0: tensor<2x3x4x!qalias>) { + // expected-error@+1 {{quantized dimension size does not match number of scales}} + %0 = quant.dcast %arg0 : tensor<2x3x4x!qalias> to tensor<2x3x4xf32> + return +} + +// ----- + +func.func @qcast_invalid_input(%arg0: f32) { + // expected-error@+1 {{result #0 must be scalar or tensor of quantized type}} + %0 = quant.qcast %arg0 : f32 to f32 + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_invalid_result(%arg0: !qalias) { + // expected-error@+1 {{operand #0 must be scalar or tensor of floating-point}} + %0 = quant.qcast %arg0 : !qalias to !qalias + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_mismatch_scalar_tensor(%arg0: tensor) { + // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}} + %0 = quant.qcast %arg0 : tensor to !qalias + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_mismatch_ranked_unranked_tensor(%arg0: tensor) { + // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}} + %0 = quant.qcast %arg0 : tensor to tensor<*x!qalias> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_mismatch_static_dynamic_tensor(%arg0: tensor<2x3xf32>) { + // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}} + %0 = quant.qcast %arg0 : tensor<2x3xf32> to tensor + return +} + 
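+// -----
+
+// For contrast with the shape-mismatch cases above, a minimal sketch of a
+// qcast that satisfies the shape-matching constraint. The quantization
+// parameters here are illustrative assumptions, not taken from the original
+// test suite. The op verifies cleanly, so no diagnostic is expected.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @qcast_shape_match_example(%arg0: tensor<2x3xf32>) {
+  %0 = quant.qcast %arg0 : tensor<2x3xf32> to tensor<2x3x!qalias>
+  return
+}
+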
+// ----- + +!qalias = !quant.uniform +func.func @qcast_float_type_mismatch(%arg0: f64) { + // expected-error@+1 {{expressed type in quantized type expected to match float type}} + %0 = quant.qcast %arg0 : f64 to !qalias + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_per_axis_scalar(%arg0: f32) { + // expected-error@+1 {{scalar types may not use per-axis quantization}} + %0 = quant.qcast %arg0 : f32 to !qalias + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_per_axis_invalid_rank(%arg0: tensor<2x3xf32>) { + // expected-error@+1 {{quantized dimension must be less than tensor rank}} + %0 = quant.qcast %arg0 : tensor<2x3xf32> to tensor<2x3x!qalias> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_per_axis_invalid_rank(%arg0: tensor<2x3x4xf32>) { + // expected-error@+1 {{quantized dimension size does not match number of scales}} + %0 = quant.qcast %arg0 : tensor<2x3x4xf32> to tensor<2x3x4x!qalias> + return +} + + diff --git a/mlir/test/Dialect/Quant/ops.mlir b/mlir/test/Dialect/Quant/ops.mlir new file mode 100644 index 0000000000000..ab3d6decfb248 --- /dev/null +++ b/mlir/test/Dialect/Quant/ops.mlir @@ -0,0 +1,96 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics + +!qalias = !quant.uniform +func.func @dcast_scalar(%arg0: !qalias) { + %0 = quant.dcast %arg0 : !qalias to f32 + return +} + +// ----- + +!qalias = !quant.uniform +func.func @dcast_ranked(%arg0: tensor<2x?x4x!qalias>) { + %0 = quant.dcast %arg0 : tensor<2x?x4x!qalias> to tensor<2x?x4xf32> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @dcast_unranked(%arg0: tensor<*x!qalias>) { + %0 = quant.dcast %arg0 : tensor<*x!qalias> to tensor<*xf32> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @dcast_per_axis_static(%arg0: tensor<1x2x3x!qalias>) { + %0 = quant.dcast %arg0 : tensor<1x2x3x!qalias> to tensor<1x2x3xf32> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @dcast_per_axis_dynamic(%arg0: tensor) { + %0 = quant.dcast %arg0 : tensor to tensor + return +} + +// ----- + +!qalias = !quant.uniform +func.func @dcast_per_axis_unranked(%arg0: tensor<*x!qalias>) { + %0 = quant.dcast %arg0 : tensor<*x!qalias> to tensor<*xf32> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_scalar(%arg0: f32) { + %0 = quant.qcast %arg0 : f32 to !qalias + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_ranked(%arg0: tensor<2x?x4xf32>) { + %0 = quant.qcast %arg0 : tensor<2x?x4xf32> to tensor<2x?x4x!qalias> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_unranked(%arg0: tensor<*xf32>) { + %0 = quant.qcast %arg0 : tensor<*xf32> to tensor<*x!qalias> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_per_axis_static(%arg0: tensor<1x2x3xf32>) { + %0 = quant.qcast %arg0 : tensor<1x2x3xf32> to tensor<1x2x3x!qalias> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_per_axis_dynamic(%arg0: tensor) { + %0 = quant.qcast %arg0 : tensor to tensor + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_per_axis_unranked(%arg0: tensor<*xf32>) { + %0 = quant.qcast %arg0 : tensor<*xf32> to tensor<*x!qalias> + return +} + From 09278de0b804824b0e6c1062e31d1a858549d23c Mon Sep 17 00:00:00 2001 From: Rafael Ubal Tena Date: Tue, 23 Jul 2024 17:56:36 -0400 Subject: [PATCH 15/22] Verifier for 'quant.scast' and unit tests --- .../mlir/Dialect/Quant/IR/QuantBase.td | 28 +++++- .../include/mlir/Dialect/Quant/IR/QuantOps.td | 11 ++- 
 mlir/lib/Dialect/Quant/IR/QuantOps.cpp        | 65 ++++++++++-----
 mlir/test/Dialect/Quant/invalid.mlir          | 97 +++++++++++++++++++
 mlir/test/Dialect/Quant/ops.mlir              | 55 +++++++++++
 5 files changed, 233 insertions(+), 23 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
index 8ef89a0d1393b..d81838db3dc1a 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
@@ -36,19 +36,24 @@ def quant_QuantizedType :
 
 def quant_ScalarType :
     Type<Or<[
-      AnyInteger.predicate,
+      AnySignlessInteger.predicate,
       AnyFloat.predicate,
       quant_QuantizedType.predicate
     ]>,
-    "integer, float, or quantized scalar">;
+    "signless integer, float, or quantized scalar">;
 
 def quant_IntegerOrQuantizedType :
-    Type<Or<[AnyInteger.predicate, quant_QuantizedType.predicate]>>;
+    Type<Or<[
+      AnySignlessInteger.predicate,
+      quant_QuantizedType.predicate
+    ]>,
+    "signless integer or quantized type">;
 
 def quant_FloatScalarOrTensor :
     quant_ScalarOrTensorOf<AnyFloat>;
 
 def quant_IntegerScalarOrTensor :
-    quant_ScalarOrTensorOf<AnyInteger>;
+    quant_ScalarOrTensorOf<AnySignlessInteger>;
 
 def quant_QuantizedScalarOrTensor :
     quant_ScalarOrTensorOf<quant_QuantizedType>;
@@ -81,4 +86,19 @@ def quant_SameScalarOrTensorShape :
       ]>
     >;
 
+def quant_IntegerAndQuantizedCombination :
+    PredOpTrait<
+      "input must be integer and result must be quantized, or vice versa",
+      Or<[
+        And<[
+          TypeIsPred<"input", quant_QuantizedScalarOrTensor>,
+          TypeIsPred<"result", quant_IntegerScalarOrTensor>
+        ]>,
+        And<[
+          TypeIsPred<"input", quant_IntegerScalarOrTensor>,
+          TypeIsPred<"result", quant_QuantizedScalarOrTensor>
+        ]>
+      ]>
+    >;
+
 #endif // QUANT_BASE
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
index 3a24c8148d1f5..5dab02e8e1ee5 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
@@ -98,7 +98,8 @@ def quant_QuantizeCastOp : quant_Op<"qcast", [
 
 def quant_StorageCastOp : quant_Op<"scast", [
     Pure,
-    quant_SameScalarOrTensorShape]> {
+    quant_SameScalarOrTensorShape,
+    quant_IntegerAndQuantizedCombination]> {
   let summary = "cast from or to a type based on the storage type and the corresponding quantized type";
   let description = [{
     A StorageCast `scast` represents a cast from or to a type based on the
@@ -122,7 +123,15 @@ def quant_StorageCastOp : quant_Op<"scast", [
   let arguments = (ins quant_IntegerOrQuantizedScalarOrTensor:$input);
   let results = (outs quant_IntegerOrQuantizedScalarOrTensor:$result);
   let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)";
+  let hasVerifier = 1;
   let hasFolder = 1;
+  let extraClassDeclaration = [{
+    /// Return the integer type used either in the input or the result.
+    IntegerType getIntegerType();
+
+    /// Return the quantized type used either in the input or the result.
+    quant::QuantizedType getQuantizedType();
+  }];
 }
 
 #endif // QUANT_OPS
diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
index ef66ccf3d9e01..f722eb8e30806 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
@@ -86,10 +86,8 @@ LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType,
     return op->emitError(
         "expressed type in quantized type expected to match float type");
 
-  if (failed(verifyPerAxisQuantization(op, quantizedType, containerType)))
-    return failure();
-
-  return success();
+  // Verify integrity of per-axis quantization information, if present.
+ return verifyPerAxisQuantization(op, quantizedType, containerType); } } // namespace @@ -128,20 +126,6 @@ QuantizedType DequantizeCastOp::getQuantizedType() { } -//===----------------------------------------------------------------------===// -// StorageCastOp -//===----------------------------------------------------------------------===// - -OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) { - // Matches x -> [scast -> scast] -> y, replacing the second scast with the - // value of x if the casts invert each other. - auto srcScastOp = getInput().getDefiningOp(); - if (!srcScastOp || srcScastOp.getInput().getType() != getType()) - return OpFoldResult(); - return srcScastOp.getInput(); -} - - //===----------------------------------------------------------------------===// // QuantizeCastOp //===----------------------------------------------------------------------===// @@ -160,6 +144,51 @@ QuantizedType QuantizeCastOp::getQuantizedType() { } +//===----------------------------------------------------------------------===// +// StorageCastOp +//===----------------------------------------------------------------------===// + +IntegerType StorageCastOp::getIntegerType() { + auto inputScalarType = getElementTypeOrSelf(getInput().getType()); + if (auto integerType = dyn_cast(inputScalarType)) + return integerType; + + auto resultScalarType = getElementTypeOrSelf(getResult().getType()); + return cast(resultScalarType); +} + +QuantizedType StorageCastOp::getQuantizedType() { + auto inputScalarType = getElementTypeOrSelf(getInput().getType()); + if (auto quantizedType = dyn_cast(inputScalarType)) + return quantizedType; + + auto resultScalarType = getElementTypeOrSelf(getResult().getType()); + return cast(resultScalarType); +} + +OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) { + // Matches x -> [scast -> scast] -> y, replacing the second scast with the + // value of x if the casts invert each other. + auto srcScastOp = getInput().getDefiningOp(); + if (!srcScastOp || srcScastOp.getInput().getType() != getType()) + return OpFoldResult(); + return srcScastOp.getInput(); +} + +LogicalResult StorageCastOp::verify() { + auto quantizedType = getQuantizedType(); + auto integerType = getIntegerType(); + if (quantizedType.getStorageType() != integerType) + return emitError( + "storage type in quantized type expected to match integer type"); + + // Verify integrity of per-axis quantization information, if available. While + // the quantization type may appear in the input or the result, their tensor + // shapes are guaranteed to be identical at this point. 
+ return verifyPerAxisQuantization(*this, quantizedType, getInput().getType()); +} + + } // namespace quant } // namespace mlir diff --git a/mlir/test/Dialect/Quant/invalid.mlir b/mlir/test/Dialect/Quant/invalid.mlir index 9976ce5d00d65..ba3a8e312d96e 100644 --- a/mlir/test/Dialect/Quant/invalid.mlir +++ b/mlir/test/Dialect/Quant/invalid.mlir @@ -158,4 +158,101 @@ func.func @qcast_per_axis_invalid_rank(%arg0: tensor<2x3x4xf32>) { return } +// ----- + +!qalias = !quant.uniform +func.func @scast_invalid_input(%arg0: si32) { + // expected-error@+1 {{operand #0 must be scalar or tensor of signless integer or quantized type}} + %0 = quant.scast %arg0 : si32 to !qalias + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_invalid_result(%arg0: !qalias) { + // expected-error@+1 {{result #0 must be scalar or tensor of signless integer or quantized type}} + %0 = quant.scast %arg0 : !qalias to si32 + return +} + +// ----- + +func.func @scast_both_integers(%arg0: i8) { + // expected-error@+1 {{input must be integer and result must be quantized, or vice versa}} + %0 = quant.scast %arg0 : i8 to i8 + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_both_quantized(%arg0: !qalias) { + // expected-error@+1 {{input must be integer and result must be quantized, or vice versa}} + %0 = quant.scast %arg0 : !qalias to !qalias + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_mismatch_scalar_tensor(%arg0: tensor) { + // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}} + %0 = quant.scast %arg0 : tensor to !qalias + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_mismatch_ranked_unranked_tensor(%arg0: tensor) { + // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}} + %0 = quant.scast %arg0 : tensor to tensor<*x!qalias> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_mismatch_static_dynamic_tensor(%arg0: tensor<2x3xi8>) { + // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}} + %0 = quant.scast %arg0 : tensor<2x3xi8> to tensor + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_integer_type_mismatch(%arg0: i32) { + // expected-error@+1 {{storage type in quantized type expected to match integer type}} + %0 = quant.scast %arg0 : i32 to !qalias + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_per_axis_scalar(%arg0: i8) { + // expected-error@+1 {{scalar types may not use per-axis quantization}} + %0 = quant.scast %arg0 : i8 to !qalias + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_per_axis_invalid_rank(%arg0: tensor<2x3xi8>) { + // expected-error@+1 {{quantized dimension must be less than tensor rank}} + %0 = quant.scast %arg0 : tensor<2x3xi8> to tensor<2x3x!qalias> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_per_axis_invalid_rank(%arg0: tensor<2x3x4xi8>) { + // expected-error@+1 {{quantized dimension size does not match number of scales}} + %0 = quant.scast %arg0 : tensor<2x3x4xi8> to tensor<2x3x4x!qalias> + return +} diff --git a/mlir/test/Dialect/Quant/ops.mlir b/mlir/test/Dialect/Quant/ops.mlir index ab3d6decfb248..4abc5830d081e 100644 --- a/mlir/test/Dialect/Quant/ops.mlir +++ b/mlir/test/Dialect/Quant/ops.mlir @@ -94,3 +94,58 @@ func.func @qcast_per_axis_unranked(%arg0: tensor<*xf32>) { return } +// ----- + +!qalias = !quant.uniform +func.func @scast_scalar(%arg0: i8) { + %0 = quant.scast %arg0 : i8 to 
!qalias + %1 = quant.scast %0 : !qalias to i8 + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_ranked(%arg0: tensor<2x?x4xi8>) { + %0 = quant.scast %arg0 : tensor<2x?x4xi8> to tensor<2x?x4x!qalias> + %1 = quant.scast %0 : tensor<2x?x4x!qalias> to tensor<2x?x4xi8> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_unranked(%arg0: tensor<*xi8>) { + %0 = quant.scast %arg0 : tensor<*xi8> to tensor<*x!qalias> + %1 = quant.scast %0 : tensor<*x!qalias> to tensor<*xi8> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_per_axis_static(%arg0: tensor<1x2x3xi8>) { + %0 = quant.scast %arg0 : tensor<1x2x3xi8> to tensor<1x2x3x!qalias> + %1 = quant.scast %0 : tensor<1x2x3x!qalias> to tensor<1x2x3xi8> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_per_axis_dynamic(%arg0: tensor) { + %0 = quant.scast %arg0 : tensor to tensor + %1 = quant.scast %0 : tensor to tensor + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_per_axis_unranked(%arg0: tensor<*xi8>) { + %0 = quant.scast %arg0 : tensor<*xi8> to tensor<*x!qalias> + %1 = quant.scast %0 : tensor<*x!qalias> to tensor<*xi8> + return +} + + From 764b1d5afaca8ac397d85157c54e2717fcefd0ac Mon Sep 17 00:00:00 2001 From: Rafael Ubal Tena Date: Wed, 24 Jul 2024 08:08:40 -0400 Subject: [PATCH 16/22] Canonicalization patterns for all ops and unit tests --- .../include/mlir/Dialect/Quant/IR/QuantOps.td | 2 + mlir/lib/Dialect/Quant/IR/QuantOps.cpp | 67 ++++++--- mlir/test/Dialect/Quant/canonicalize.mlir | 134 +++++++++++++++--- 3 files changed, 164 insertions(+), 39 deletions(-) diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td index 5dab02e8e1ee5..036940119b349 100644 --- a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td +++ b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td @@ -48,6 +48,7 @@ def quant_DequantizeCastOp : quant_Op<"dcast", [ let results = (outs quant_FloatScalarOrTensor:$result); let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)"; let hasVerifier = 1; + let hasFolder = 1; let extraClassDeclaration = [{ /// Return the float type of the scalar or tensor result. FloatType getFloatType(); @@ -87,6 +88,7 @@ def quant_QuantizeCastOp : quant_Op<"qcast", [ let results = (outs quant_QuantizedScalarOrTensor:$result); let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)"; let hasVerifier = 1; + let hasFolder = 1; let extraClassDeclaration = [{ /// Return the float type of the scalar or tensor input. FloatType getFloatType(); diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp index f722eb8e30806..6a709488ce01c 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp @@ -117,6 +117,17 @@ LogicalResult DequantizeCastOp::verify() { getInput().getType()); } +OpFoldResult DequantizeCastOp::fold(FoldAdaptor adaptor) { + // Matches x -> quant.qcast -> quant.dcast -> y, replacing the quant.dcast op + // with the value of x. Values x and y are guaranteed to be of the same type + // in this pattern. 
+  auto srcQcastOp = getInput().getDefiningOp<QuantizeCastOp>();
+  if (!srcQcastOp)
+    return {};
+  assert(srcQcastOp.getInput().getType() == getType());
+  return srcQcastOp.getInput();
+}
+
 FloatType DequantizeCastOp::getFloatType() {
   return cast<FloatType>(getElementTypeOrSelf(getResult().getType()));
 }
@@ -135,6 +146,18 @@ LogicalResult QuantizeCastOp::verify() {
                               getInput().getType());
 }
 
+OpFoldResult QuantizeCastOp::fold(FoldAdaptor adaptor) {
+  // Matches x -> quant.dcast -> quant.qcast -> y, replacing the quant.qcast
+  // op with the value of x if the casts invert each other. Contrary to the
+  // folding pattern in quant.dcast (i.e., x -> quant.qcast -> quant.dcast ->
+  // y), values x and y are not guaranteed to be of the same type here, as
+  // they may use different quantization parameters.
+  auto srcDcastOp = getInput().getDefiningOp<DequantizeCastOp>();
+  if (!srcDcastOp || srcDcastOp.getInput().getType() != getType())
+    return {};
+  return srcDcastOp.getInput();
+}
+
 FloatType QuantizeCastOp::getFloatType() {
   return cast<FloatType>(getElementTypeOrSelf(getInput().getType()));
 }
@@ -148,6 +171,28 @@ QuantizedType QuantizeCastOp::getQuantizedType() {
 // StorageCastOp
 //===----------------------------------------------------------------------===//
 
+LogicalResult StorageCastOp::verify() {
+  auto quantizedType = getQuantizedType();
+  auto integerType = getIntegerType();
+  if (quantizedType.getStorageType() != integerType)
+    return emitError(
+        "storage type in quantized type expected to match integer type");
+
+  // Verify integrity of per-axis quantization information, if available.
+  // While the quantization type may appear in the input or the result, their
+  // tensor shapes are guaranteed to be identical at this point.
+  return verifyPerAxisQuantization(*this, quantizedType, getInput().getType());
+}
+
+OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) {
+  // Matches x -> quant.scast -> quant.scast -> y, replacing the second
+  // quant.scast with the value of x if the casts invert each other.
+  auto srcScastOp = getInput().getDefiningOp<StorageCastOp>();
+  if (!srcScastOp || srcScastOp.getInput().getType() != getType())
+    return {};
+  return srcScastOp.getInput();
+}
+
 IntegerType StorageCastOp::getIntegerType() {
   auto inputScalarType = getElementTypeOrSelf(getInput().getType());
   if (auto integerType = dyn_cast<IntegerType>(inputScalarType))
     return integerType;
 
   auto resultScalarType = getElementTypeOrSelf(getResult().getType());
   return cast<IntegerType>(resultScalarType);
 }
 
 QuantizedType StorageCastOp::getQuantizedType() {
   auto inputScalarType = getElementTypeOrSelf(getInput().getType());
   if (auto quantizedType = dyn_cast<QuantizedType>(inputScalarType))
     return quantizedType;
 
   auto resultScalarType = getElementTypeOrSelf(getResult().getType());
   return cast<QuantizedType>(resultScalarType);
 }
 
-OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) {
-  // Matches x -> [scast -> scast] -> y, replacing the second scast with the
-  // value of x if the casts invert each other.
-  auto srcScastOp = getInput().getDefiningOp<StorageCastOp>();
-  if (!srcScastOp || srcScastOp.getInput().getType() != getType())
-    return OpFoldResult();
-  return srcScastOp.getInput();
-}
-
-LogicalResult StorageCastOp::verify() {
-  auto quantizedType = getQuantizedType();
-  auto integerType = getIntegerType();
-  if (quantizedType.getStorageType() != integerType)
-    return emitError(
-        "storage type in quantized type expected to match integer type");
-
-  // Verify integrity of per-axis quantization information, if available.
-  // While the quantization type may appear in the input or the result, their
-  // tensor shapes are guaranteed to be identical at this point.
- return verifyPerAxisQuantization(*this, quantizedType, getInput().getType()); -} - } // namespace quant } // namespace mlir diff --git a/mlir/test/Dialect/Quant/canonicalize.mlir b/mlir/test/Dialect/Quant/canonicalize.mlir index 36c3eaf5e10d2..73c57e2a48212 100644 --- a/mlir/test/Dialect/Quant/canonicalize.mlir +++ b/mlir/test/Dialect/Quant/canonicalize.mlir @@ -1,24 +1,124 @@ // RUN: mlir-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize{test-convergence}))' | FileCheck %s +// CHECK-LABEL: @dcast_fold +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK: return %[[ARG_0]] + +!qalias = !quant.uniform +func.func @dcast_fold(%arg0: tensor<4xf32>) -> tensor<4xf32> { + %0 = quant.qcast %arg0 : tensor<4xf32> to tensor<4x!qalias> + %1 = quant.dcast %0 : tensor<4x!qalias> to tensor<4xf32> + return %1 : tensor<4xf32> +} + // ----- -// CHECK-LABEL: redundant_scast -func.func @redundant_scast() -> tensor<4xi8> { - // CHECK-NEXT: arith.constant dense<10> : tensor<4xi8> - // CHECK-NEXT: return - %cst = arith.constant dense<5> : tensor<4xi8> - %1 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant.uniform> - %2 = "quant.scast"(%1) : (tensor<4x!quant.uniform>) -> tensor<4xi8> - %3 = arith.addi %2, %2 : tensor<4xi8> - return %3 : tensor<4xi8> + +// CHECK-LABEL: @dcast_no_fold_source +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK: %[[VAL_0:.*]] = quant.scast %[[ARG_0]] +// CHECK: %[[VAL_1:.*]] = quant.dcast %[[VAL_0]] +// CHECK: return %[[VAL_1]] + +!qalias = !quant.uniform +func.func @dcast_no_fold_source(%arg0: tensor<4xi8>) -> tensor<4xf32> { + %0 = quant.scast %arg0 : tensor<4xi8> to tensor<4x!qalias> + %1 = quant.dcast %0 : tensor<4x!qalias> to tensor<4xf32> + return %1 : tensor<4xf32> } // ----- -// CHECK-LABEL: non_redundant_scast -func.func @non_redundant_scast() -> tensor<4x!quant.uniform> { - // CHECK-NEXT: arith.constant dense<5> : tensor<4xi8> - // CHECK-NEXT: scast - // CHECK-NEXT: return - %cst = arith.constant dense<5> : tensor<4xi8> - %1 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant.uniform> - return %1 : tensor<4x!quant.uniform> + +// CHECK-LABEL: @qcast_fold +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK: return %[[ARG_0]] + +!qalias = !quant.uniform +func.func @qcast_fold(%arg0: tensor<4x!qalias>) -> tensor<4x!qalias> { + %0 = quant.dcast %arg0 : tensor<4x!qalias> to tensor<4xf32> + %1 = quant.qcast %0 : tensor<4xf32> to tensor<4x!qalias> + return %1 : tensor<4x!qalias> } + +// ----- + +// CHECK-LABEL: @qcast_no_fold_source +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK: %[[VAL_0:.*]] = arith.negf %[[ARG_0]] +// CHECK: %[[VAL_1:.*]] = quant.qcast %[[VAL_0]] +// CHECK: return %[[VAL_1]] + +!qalias = !quant.uniform +func.func @qcast_no_fold_source(%arg0: tensor<4xf32>) -> tensor<4x!qalias> { + %0 = arith.negf %arg0 : tensor<4xf32> + %1 = quant.qcast %0 : tensor<4xf32> to tensor<4x!qalias> + return %1 : tensor<4x!qalias> +} + +// ----- + +// CHECK-LABEL: @qcast_no_fold_type +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK: %[[VAL_0:.*]] = quant.dcast %[[ARG_0]] +// CHECK: %[[VAL_1:.*]] = quant.qcast %[[VAL_0]] +// CHECK: return %[[VAL_1]] + +!qalias = !quant.uniform +!qalias1 = !quant.uniform +func.func @qcast_no_fold_type(%arg0: tensor<4x!qalias>) -> tensor<4x!qalias1> { + %0 = quant.dcast %arg0 : tensor<4x!qalias> to tensor<4xf32> + %1 = quant.qcast %0 : tensor<4xf32> to tensor<4x!qalias1> + return %1 : tensor<4x!qalias1> +} + +// ----- + +// CHECK-LABEL: @scast_fold +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK: return 
%[[ARG_0]] + +!qalias = !quant.uniform +func.func @scast_fold(%arg0: tensor<4x!qalias>) -> tensor<4x!qalias> { + %0 = quant.scast %arg0 : tensor<4x!qalias> to tensor<4xi8> + %1 = quant.scast %0 : tensor<4xi8> to tensor<4x!qalias> + return %1 : tensor<4x!qalias> +} + +// ----- + +// CHECK-LABEL: @scast_no_fold_source +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK: %[[QCAST:.*]] = quant.qcast %[[ARG_0]] +// CHECK: %[[SCAST:.*]] = quant.scast %[[QCAST]] +// CHECK: return %[[SCAST]] + +!qalias = !quant.uniform +func.func @scast_no_fold_source(%arg0: tensor<4xf32>) -> tensor<4xi8> { + %0 = quant.qcast %arg0 : tensor<4xf32> to tensor<4x!qalias> + %1 = quant.scast %0 : tensor<4x!qalias> to tensor<4xi8> + return %1 : tensor<4xi8> +} + +// ----- + +// CHECK-LABEL: @scast_no_fold_type +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK: %[[VAL_0:.*]] = quant.scast %[[ARG_0]] +// CHECK: %[[VAL_1:.*]] = quant.scast %[[VAL_0]] +// CHECK: return %[[VAL_1]] + +!qalias = !quant.uniform +!qalias1 = !quant.uniform +func.func @scast_no_fold_type(%arg0: tensor<4x!qalias>) -> tensor<4x!qalias1> { + %0 = quant.scast %arg0 : tensor<4x!qalias> to tensor<4xi8> + %1 = quant.scast %0 : tensor<4xi8> to tensor<4x!qalias1> + return %1 : tensor<4x!qalias1> +} + From 70c5af961b01b91b7d1664ebad7fa528a37b35af Mon Sep 17 00:00:00 2001 From: Rafael Ubal Tena Date: Thu, 25 Jul 2024 00:23:28 -0400 Subject: [PATCH 17/22] Pass 'strip-func-quant-types' --- .../mlir/Dialect/Quant/Transforms/Passes.td | 15 +++ .../Dialect/Quant/Transforms/CMakeLists.txt | 1 + .../Quant/Transforms/StripFuncQuantTypes.cpp | 114 ++++++++++++++++++ .../Dialect/Quant/strip-func-quant-types.mlir | 88 ++++++++++++++ 4 files changed, 218 insertions(+) create mode 100644 mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp create mode 100644 mlir/test/Dialect/Quant/strip-func-quant-types.mlir diff --git a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td index 56e10688b0c98..b25296d4db5a9 100644 --- a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td @@ -31,4 +31,19 @@ def LowerQuantOps : Pass<"lower-quant-ops", "func::FuncOp"> { ]; } +def StripFuncQuantTypes : Pass<"strip-func-quant-types"> { + let summary = "Strip quantized types from function headers"; + let description = [{ + Identify occurrences of function arguments using a quantized type and + replace them with a new value of the corresponding storage (signless + integer) type. For each converted argument, a `quant.scast` op is introduced + at the head of the function's entry block converting the new integer + argument into the original quantized value. 
+ }]; + let dependentDialects = [ + "func::FuncDialect", + "quant::QuantDialect" + ]; +} + #endif // MLIR_DIALECT_QUANT_TRANSFORMS_PASSES diff --git a/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt index 2daea7750cfe3..662c3e368b624 100644 --- a/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRQuantTransforms LowerQuantOps.cpp + StripFuncQuantTypes.cpp ADDITIONAL_HEADER_DIRS {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Quant/Transforms diff --git a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp new file mode 100644 index 0000000000000..8996eff61a39c --- /dev/null +++ b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp @@ -0,0 +1,114 @@ +//===- StripFuncQuantTypes.cpp - Strip quantized types --------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Strips quantized types from function headers. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Quant/IR/Quant.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" +#include "mlir/Dialect/Quant/Transforms/Passes.h" +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace quant { + +#define GEN_PASS_DEF_STRIPFUNCQUANTTYPES +#include "mlir/Dialect/Quant/Transforms/Passes.h.inc" + +namespace { + +class QuantizedTypeConverter : public TypeConverter { + + static Type convertQuantizedType(QuantizedType quantizedType) { + return quantizedType.getStorageType(); + } + + static Type convertTensorType(TensorType tensorType) { + if (auto quantizedType = dyn_cast(tensorType.getElementType())) + return tensorType.clone(convertQuantizedType(quantizedType)); + return tensorType; + } + + static Value materializeConversion(OpBuilder &builder, Type type, + ValueRange inputs, Location loc) { + assert(inputs.size() == 1); + return builder.create(loc, type, inputs[0]); + } + +public: + + explicit QuantizedTypeConverter() { + addConversion([](Type type) { return type; }); + addConversion(convertQuantizedType); + addConversion(convertTensorType); + + addArgumentMaterialization(materializeConversion); + addSourceMaterialization(materializeConversion); + addTargetMaterialization(materializeConversion); + } +}; + +// Conversion pass +class StripFuncQuantTypes : public impl::StripFuncQuantTypesBase { + + // Return whether a type is considered legal when occurring in the header of + // a function or as an operand to a 'return' op. 
+ static bool isLegalType(Type type) { + if (auto tensorType = dyn_cast(type)) + return isLegalType(tensorType.getElementType()); + return !isa(type); + } + +public: + + void runOnOperation() override { + + auto moduleOp = cast(getOperation()); + auto* context = &getContext(); + + QuantizedTypeConverter typeConverter; + ConversionTarget target(*context); + RewritePatternSet patterns(context); + + // Mark func.func, func.return, and func.call illegal if they contain any + // quantized types. + target.addDynamicallyLegalOp([&](func::FuncOp op) { + return typeConverter.isSignatureLegal(op.getFunctionType()) && + typeConverter.isLegal(&op.getBody()); + }); + target.addDynamicallyLegalOp( + [&](func::ReturnOp op) { return typeConverter.isLegal(op); }); + target.addDynamicallyLegalOp( + [&](func::CallOp op) { return typeConverter.isLegal(op); }); + + // Register conversion patterns + populateFunctionOpInterfaceTypeConversionPattern( + patterns, typeConverter); + populateReturnOpTypeConversionPattern(patterns, typeConverter); + populateCallOpTypeConversionPattern(patterns, typeConverter); + + // Apply conversion + if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace + +} // namespace quant +} // namespace mlir + diff --git a/mlir/test/Dialect/Quant/strip-func-quant-types.mlir b/mlir/test/Dialect/Quant/strip-func-quant-types.mlir new file mode 100644 index 0000000000000..e5f0d4921bed3 --- /dev/null +++ b/mlir/test/Dialect/Quant/strip-func-quant-types.mlir @@ -0,0 +1,88 @@ +// RUN: mlir-opt %s --strip-func-quant-types --split-input-file | FileCheck %s + +// CHECK-LABEL: @strip_operands +// CHECK-SAME: %[[ARG_0:.*]]: i8 +// CHECK-SAME: %[[ARG_1:.*]]: i16 +// CHECK-SAME: %[[ARG_2:.*]]: f32 + +// CHECK: %[[ARG_0_CAST:.*]] = quant.scast %[[ARG_1]] : i16 to !quant.uniform<{{.*}}> +// CHECK: %[[ARG_1_CAST:.*]] = quant.scast %[[ARG_0]] : i8 to !quant.uniform<{{.*}}> + +// CHECK: "test.custom_op"(%[[ARG_1_CAST]]) +// CHECK: "test.custom_op"(%[[ARG_0_CAST]]) +// CHECK: "test.custom_op"(%[[ARG_2]]) + +!qalias = !quant.uniform +!qalias1 = !quant.uniform + +func.func @strip_operands(%arg0: !qalias, %arg1: !qalias1, %arg2: f32) { + "test.custom_op"(%arg0) : (!qalias) -> tensor<4x!qalias> + "test.custom_op"(%arg1) : (!qalias1) -> tensor + "test.custom_op"(%arg2) : (f32) -> tensor<4xf32> +} + +// ----- + +// CHECK-LABEL: @strip_results +// CHECK-SAME: tensor<4xi8>, tensor, tensor<*xi8>, tensor<4xf32> + +// CHECK: %[[RESULT_0:.*]] = "test.custom_op"() +// CHECK: %[[RESULT_CAST_0:.*]] = quant.scast %[[RESULT_0]] : tensor<4x!quant.uniform<{{.*}}>> to tensor<4xi8> + +// CHECK: %[[RESULT_1:.*]] = "test.custom_op"() +// CHECK: %[[RESULT_CAST_1:.*]] = quant.scast %[[RESULT_1]] : tensor> to tensor + +// CHECK: %[[RESULT_2:.*]] = "test.custom_op"() +// CHECK: %[[RESULT_CAST_2:.*]] = quant.scast %[[RESULT_2]] : tensor<*x!quant.uniform<{{.*}}>> to tensor<*xi8> + +// CHECK: %[[RESULT_3:.*]] = "test.custom_op"() + +// CHECK: return %[[RESULT_CAST_0]], %[[RESULT_CAST_1]], %[[RESULT_CAST_2]], %[[RESULT_3]] + +!qalias = !quant.uniform +!qalias1 = !quant.uniform + +func.func @strip_results() -> (tensor<4x!qalias>, tensor, tensor<*x!qalias>, tensor<4xf32>) { + %0 = "test.custom_op"() : () -> tensor<4x!qalias> + %1 = "test.custom_op"() : () -> tensor + %2 = "test.custom_op"() : () -> tensor<*x!qalias> + %3 = "test.custom_op"() : () -> tensor<4xf32> + return %0, %1, %2, %3 : tensor<4x!qalias>, tensor, tensor<*x!qalias>, tensor<4xf32> +} + +// ----- + + +// 
CHECK-LABEL: @callee +// CHECK-SAME: (tensor<4xi8>, tensor) -> (tensor<*xi8>, tensor<4xf32>) + +// CHECK-LABEL: @strip_call + +// CHECK: %[[OPERAND_0:.*]] = "test.custom_op"() +// CHECK: %[[OPERAND_0_CAST:.*]] = quant.scast %[[OPERAND_0]] : tensor<4x!quant.uniform<{{.*}}>> to tensor<4xi8> + +// CHECK: %[[OPERAND_1:.*]] = "test.custom_op"() +// CHECK: %[[OPERAND_1_CAST:.*]] = quant.scast %[[OPERAND_1]] : tensor> to tensor + +// CHECK: %[[RESULTS:.*]]:2 = call @callee(%[[OPERAND_0_CAST]], %[[OPERAND_1_CAST]]) + +// CHECK: %[[RESULT_0_CAST:.*]] = quant.scast %[[RESULTS]]#0 : tensor<*xi8> to tensor<*x!quant.uniform<{{.*}}>> +// CHECK: "test.custom_op"(%[[RESULT_0_CAST]]) + +// CHECK: "test.custom_op"(%[[RESULTS]]#1) + +// CHECK: return + +!qalias = !quant.uniform +!qalias1 = !quant.uniform + +func.func private @callee(tensor<4x!qalias>, tensor) -> (tensor<*x!qalias>, tensor<4xf32>) + +func.func @strip_call() { + %0 = "test.custom_op"() : () -> tensor<4x!qalias> + %1 = "test.custom_op"() : () -> tensor + %2:2 = func.call @callee(%0, %1) : (tensor<4x!qalias>, tensor) -> (tensor<*x!qalias>, tensor<4xf32>) + "test.custom_op"(%2#0) : (tensor<*x!qalias>) -> () + "test.custom_op"(%2#1) : (tensor<4xf32>) -> () + return +} From cce8171c6d016d823e514ec304f94d2e8c4085c0 Mon Sep 17 00:00:00 2001 From: Rafael Ubal Tena Date: Thu, 25 Jul 2024 18:43:30 -0400 Subject: [PATCH 18/22] Dialect documentation progress --- .../mlir/Dialect/Quant/IR/QuantBase.td | 53 +++++- .../include/mlir/Dialect/Quant/IR/QuantOps.td | 178 +++++++++++++----- 2 files changed, 186 insertions(+), 45 deletions(-) diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td index d81838db3dc1a..2690b4fe0b111 100644 --- a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td +++ b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td @@ -17,8 +17,59 @@ include "mlir/IR/OpBase.td" def Quant_Dialect : Dialect { let name = "quant"; + let description = [{ + ## Per-axis quantization integrity + + When type `!quant.uniform` contains per-axis quantization information, the + rules below are enforced. These rules guarantee that the quantization + information encoded in the data type is applicable to the context in which + the quantized type is used. For efficiency, these rules are actively + enforced by the verifiers of `quant` dialect ops, but they must be + respected in any context in which the `!quant.uniform` data type is used, + such as the header of a `func.func` op, or the input of an arithmetic + operation. + + - A quantized type with per-channel quantization information must be the + element type of a tensor container type, and may not occur directly as + the data type of a scalar value. + + ``` + // Incorrect. Type !quant.uniform specifies per-channel quantization for a + // scalar type. + %result = quant.qcast %input : f32 to !quant.uniform + + // Correct. Type `!quant.uniform` with per-channel quantization is wrapped in + // a `tensor` type. + %result = quant.qcast %input : tensor<2xf32> to tensor<2x!quant.uniform> + ``` + + - If the tensor containing the `!quant.uniform` type is ranked, its rank + must be greater than the channel axis specified in the quantized type. + + ``` + // Incorrect. The tensor rank (2) is not greater than the channel axis in the + // quantized type (3). + %result = quant.qcast %input : tensor<1x2xf32> to tensor<1x2x!quant.uniform> + + // Correct. 
The tensor rank (2) is now greater than the channel axis (1): + %result = quant.qcast %input : tensor<1x2xf32> to tensor<1x2x!quant.uniform> + ``` + + - If the axis dimension in the containing tensor is static, its size must + be equal to the number of scales present in the quantized type. + + ``` + // Incorrect. The channel axis is 1, and the size of dimension 1 in the + // containing tensor is 3. However, there are 4 scale values present in the + // quantized type. + %result = quant.qcast %input : tensor to tensor> + + // Correct. The quantized type now includes 3 scale values, matching the size + // of dimension 1 of the result tensor. + %result = quant.qcast %input : tensor to tensor> + ``` + }]; let cppNamespace = "::mlir::quant"; - let useDefaultTypePrinterParser = 1; } diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td index 036940119b349..52dfc6b051de7 100644 --- a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td +++ b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td @@ -31,18 +31,55 @@ class quant_Op traits> : def quant_DequantizeCastOp : quant_Op<"dcast", [ Pure, quant_SameScalarOrTensorShape]> { - let summary = "convert back from a quantized to quantizable (expressed) type operation"; + let summary = "Dequantize cast operation"; let description = [{ - A DequantizeCast op `dcast` represents the inverse of a `qcast`, - converting back from a quantized to quantizable (expressed) type. + Convert an input quantized value into its expressed floating-point value. + The dequantization process consists of the following steps: - Like `qcast`s, a `dcast` is allowed to have both its operand and result - as non quantized types. This facilitates transformations and marks edges - where the computation must be carried out in the expressed type. + ``` + def dequantize(quantizedValue: quantizedType) -> expressedType: + storedValue = reinterpretCast(quantizedValue, storageType) + storedValueFloat = convertIntToFloat(storedValue, expressedType) + zeroPointFloat = convertIntToFloat(zeroPoint, expressedType) + expressedValue = (storedValueFloat - zeroPointFloat) * scale + return expressedValue + ``` + + Here, `storageType`, `expressedType`, `scale`, and `zeroPoint` are obtained + from the corresponding parameters encoded in `quantizedType`. For + per-channel quantization, the appropriate `scale` and `zeroPoint` values + are used for each tensor element computation according to the channel the + element belongs to. + + The operation must satisfy the following syntactic constraints: + + - Operand `input` must be a scalar or tensor of type `!quant.uniform`. + + - The result type must be a floating-point scalar or tensor. + + - The `expressedType` parameter of the `!quant.uniform` type of the input + must match the floating-point type of the result. + + - The operand and result types must be both scalars or both tensors. If + tensors, they must be both ranked or both unranked. If ranked, both must + have the same shape, including matching static and dynamic dimensions. + + - If the operand uses per-channel quantization, its `!quant.uniform` type + must adhere to the [Per-axis quantization + integrity](#per-axis-quantization-integrity) guidelines. 
+ + Examples: + + ``` + // Dequantize a scalar quantized value + %result = quant.dcast %input : !quant.uniform to f32 + + // Dequantize a dynamically shaped tensor of quantized values + %result = quant.dcast %input : tensor> to tensor - Especially early in transformation, it is common to have `dcast`s on - all operands to ops that must operate with the expressed type (typically - math ops prior to lowering to target-specific, quantized kernels). + // Dequantize an unranked tensor using per-axis quantization information + %result = quant.dcast %input : tensor<*x!quant.uniform> to tensor<*xf32> + ``` }]; let arguments = (ins quant_QuantizedScalarOrTensor:$input); let results = (outs quant_FloatScalarOrTensor:$result); @@ -61,28 +98,57 @@ def quant_DequantizeCastOp : quant_Op<"dcast", [ def quant_QuantizeCastOp : quant_Op<"qcast", [ Pure, quant_SameScalarOrTensorShape]> { - let summary = "convert a quantizable type to a quantized type"; + let summary = "Quantize cast operation"; let description = [{ - A QuantizeCast `qcast` represents a potential type shift from a quantizable - type to a quantized type. - - At runtime, a `qcast` will apply the transformation expressed by its - operand and result type. For flexibility during transformation, it is also - possible to have a `qcast` that performs no transformation (both its - operand and result type are quantizable). - - A `qcast` will typically originate from either: - a) An expressed or implied constraint in the source dialect which signals - that a certain level of quantization is possible or required. - b) An inference made by a quantization algorithm indicating that a - quantized representation may be acceptable. - - Especially early in transformation, it is common to have pairs of - `qcast` and `dcast` at points where a transition to a quantized type is - required. In addition, it is also common to have an identity `qcast` - (where the operand and result type are not quantized) at all points where - it is legal to use a quantized representation (but is not known to be - acceptable). + Convert a floating-point value to a quantized type. The quantization + process consists of the following steps: + + ``` + def quantize(expressedValue: expressedType) -> quantizedType: + zeroPointFloat = convertIntToFloat(zeroPoint, expressedType) + scaledValue = expressedValue / scale + storedValueFloat = scaledValue + zeroPointFloat + storedValue = convertFloatToInt(storedValueFloat, storageType) + storedValueClamped = clamp(storedValue, storageMin, storageMax) + quantizedValue = reinterpretCast(storedValueClamped, quantizedType) + return quantizedValue + ``` + + Here, `storageType`, `storageMin`, `storageMax`, `expressedType`, `scale`, + and `zeroPoint` are obtained from the corresponding parameters encoded in + `quantizedType`. For per-channel quantization, the appropriate `scale` and + `zeroPoint` values are used for each tensor element computation according + to the channel the element belongs to. + + The operation must satisfy the following syntactic constraints: + + - Operand `input` must be a floating-point scalar or tensor. + + - The result type must be a scalar or tensor of type `!quant.uniform`. + + - The `expressedType` parameter in the `!quant.uniform` type of the result + must match the floating-point type of the input. + + - The operand and result types must be both scalars or both tensors. If + tensors, they must be both ranked or both unranked. If ranked, both must + have the same shape, including matching static and dynamic dimensions. 
+ + - If the result uses per-channel quantization, its `!quant.uniform` type + must adhere to the [Per-axis quantization + integrity](#per-axis-quantization-integrity) guidelines. + + Examples: + + ``` + // Quantize a scalar floating-point value + %result = quant.qcast %input : f32 to !quant.uniform + + // Quantize a dynamically shaped tensor of quantized values + %result = quant.qcast %input : tensor to tensor> + + // Quantize an unranked tensor using per-axis quantization information + %result = quant.qcast %input : tensor<*xf32> to tensor<*x!quant.uniform> + ``` }]; let arguments = (ins quant_FloatScalarOrTensor:$input); let results = (outs quant_QuantizedScalarOrTensor:$result); @@ -102,24 +168,48 @@ def quant_StorageCastOp : quant_Op<"scast", [ Pure, quant_SameScalarOrTensorShape, quant_IntegerAndQuantizedCombination]> { - let summary = "cast from or to a type based on the storage type and the corresponding quantized type"; + let summary = "Storage cast operation"; let description = [{ - A StorageCast `scast` represents a cast from or to a type based on the - storage type and a type based on a corresponding quantized type. + Convert a value from a quantized type to the corresponding signless integer + storage type, or vice versa. This conversion simply involves a + reinterpretation of the input bits and does not involve any data + manipulation. - This op exists to ensure type coherency for between parts of the computation - which are operating directly on an underlying storage type and those which - operate on quantized values. + The following syntactic restrictions must be met: + + - Operand `input` must be a scalar or tensor of a signless integer or + `!quant.uniform` type. + + - The result must be a scalar or tensor of a signless integer or + `!quant.uniform` type. + + - If the operand is a scalar or tensor of type integer, the result must be + a scalar or tensor of type `!quant.uniform`, and vice versa. + + - The operand and result must be both scalars or both tensors. If tensors, + they must be both ranked or both unranked. If ranked, both must have the + same shape, including matching static and dynamic dimensions. + + - The width of the `storageType` parameter of the quantized type of the + operand or result must match the width of the signless integer type of + the operand or result. + + - If the operand or result uses per-channel quantization, its + `!quant.uniform` type must adhere to the [Per-axis quantization + integrity](#per-axis-quantization-integrity) guidelines. 
+ + Examples: - Examples from storage to quantized type: - ``` - i8 -> !quant<"uniform[i8:f32]{1.0}"> - ``` - ``` - tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> - ``` ``` - vector<4xi8> -> vector<4x!quant<"uniform[i8:f32]{1.0}">> + // Cast a scalar quantized value into its storage type + %result = quant.scast %input : !quant.uniform to i8 + + // Cast a dynamically shaped tensor of quantized values into their storage type + %result = quant.scast %input : tensor> to tensor + + // Cast an unranked tensor of signless integers into a quantized type using + // per-channel quantization + %result = quant.scast %input : tensor<*xi8> to tensor<*x!quant.uniform> ``` }]; let arguments = (ins quant_IntegerOrQuantizedScalarOrTensor:$input); From 8a990c8de1ce8ef39d8cdca0fcf8d89b00b469a0 Mon Sep 17 00:00:00 2001 From: Rafael Ubal Tena Date: Fri, 26 Jul 2024 12:40:07 -0400 Subject: [PATCH 19/22] Remaining documentation for the 'quant' dialect, including the 'quant.uniform' data type --- .../mlir/Dialect/Quant/IR/QuantBase.td | 154 +++++++++++++++++- 1 file changed, 148 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td index 2690b4fe0b111..0a014b564b058 100644 --- a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td +++ b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td @@ -18,6 +18,148 @@ include "mlir/IR/OpBase.td" def Quant_Dialect : Dialect { let name = "quant"; let description = [{ + The `quant` dialect offers a framework for defining and manipulating + quantized values. Central to this framework is the `!quant.uniform` data + type, used to represent quantized values. This dialect also provides a + suite of operations to handle and convert quantized values between their + original floating-point representations and the optimized, lower bit-width + integer representations. The `quant` dialect is instrumented with + transformation passes to lower these operations into other core MLIR + dialects, while also flattening all occurrences of quantized types into + their integer counterparts. + + + ## The `!quant.uniform` type + + The quantization process establishes a relationship between two types of + values: an *expressed value* and a *stored value*. The former refers to the + floating-point representation used in an original machine learning model, + capturing the precise numerical characteristics needed for accurate + calculations. The latter is the simplified integer representation that + resides in memory after quantization. The `!quant.uniform` data type + encodes the necessary information for (lossy) round-trip conversion between + an expressed and a stored value. + + The `quant.uniform` type has two variants: per-layer quantization and + per-channel (or per-axis) quantization. In per-layer quantization, the + quantization information affects an entire tensor uniformly. Conversely, in + per-channel quantization, the data type encodes the specific tensor axis + that serves as the channel and includes quantization information for each + individual channel within the tensor. Below are the specific syntactic and + semantic considerations for each modality. + + + ### Per-layer quantization + + This is the general syntax of the `!quant.uniform` type representing + per-layer quantization: + + ``` + `!quant.uniform` `<` + storedType (`<` storageMin `:` storageMax `>`)? `:` + expressedType `,` + scale (`:` zeroPoint)? 
+ `>` + ``` + + The type contains the following parameters: + + - `storedType`: Integer type of the value stored in memory. This type + conveys the bit width and signedness of the quantized stored value. + Signed integer types are represented as `'i' bitWidth` (e.g., `i8`), + while unsigned integer types are represented as `'u' bitWidth` (e.g., + `u8`). + + - `storageMin`, `storageMax`: Optional bounds for the stored value. If + given, they must be within the range of `storedType`. If omitted, the + entire range of `storedType` is allowed (e.g., `-128...127` for `i8` or + `0...255` for `u8`). + + - `expressedType`: Floating-point type of the value expressed by this + quantized type. + + - `scale`: Floating-point value of type `expressedType` used in the + conversion between stored and expressed values. + + - `zeroPoint`: Optional integer value of type `storageType` used in the + conversion between stored and expressed values. If omitted, the default + is 0. + + Type conversions, rounding methods, and clamping actions aside, the + relationship between the expressed and stored values as encoded in a + quantized type is denoted by the following formula: + + $$ + expressedValue = (storedValue ~-~ zeroPoint) ~\times~ scale + $$ + + Operations `quant.qcast` (quantize cast) and `quant.dcast` (dequantize + cast) can be used to quantize a floating-point value and dequantize a + stored value, respectively. See the documentation for these operations for + details on how the quantization and dequantization processes are influenced + by the `!quant.uniform` type parameters. + + Here are some examples of the use of `!quant.uniform` with per-layer + quantization: + + ``` + // An 8-bit signed integer type is used to represent a 32-bit float. No + // clamping information is provided, so the full [-128, 127] range is + // available. The scale is set to 3.0, and the zero point takes its default + // 0 value. + !quant.uniform + + // A 16-bit unsigned integer type is used to represent a 32-bit float. Out + // of the 16 bits, only 10 are used, acoording to the 0..1023 clamping + // range. The type sets the scale to 1.23 and the zero point to 512. + !quant.uniform:f32, 1.23:512> + ``` + + ### Per-channel quantization + + The general syntax of the `!quant.uniform` type representing per-channel + quantization is as follows: + + ``` + `!quant.uniform` `<` + storedType (`<` storageMin `:` storageMax `>`)? `:` + expressedType `:` + channelAxis `,` + `{` + scale0 (`:` zeroPoint0)? `,` + scale1 (`:` zeroPoint1)? ... + '}' + `>` + ``` + + In this data type, there are multiple pairs of `scale` and `zeroPoint` + values. The `channelAxis` field represents the dimension of the containing + tensor acting as the channel. The size of the tensor along this dimension + is expected to match the number of provided `scale`-`zeroPoint` pairs, and + a given pair *i* applies to all elements in the tensor whose index along + dimension `channelAxis` is *i*. A quantized data type using per-channel + quantization is always expected to be contained within a tensor type. + + Here are some examples: + + ``` + // A 2x3x4 tensor contains 8-bit signed integers representing 32-bit + // floats. Dimension 1 of the tensor acts as the channel dimension. Its + // size 3 matches the number of provided scale values. Tensor elemenets at + // positions [*][0][*], [*][1][*], and [*][2][*] use scales 3.0, 4.0, and + // 5.0, respectively. 
+
+    ### Per-channel quantization
+
+    The general syntax of the `!quant.uniform` type representing per-channel
+    quantization is as follows:
+
+    ```
+    `!quant.uniform` `<`
+      storedType (`<` storageMin `:` storageMax `>`)? `:`
+      expressedType `:`
+      channelAxis `,`
+      `{`
+        scale0 (`:` zeroPoint0)? `,`
+        scale1 (`:` zeroPoint1)? ...
+      `}`
+    `>`
+    ```
+
+    In this data type, there are multiple pairs of `scale` and `zeroPoint`
+    values. The `channelAxis` field represents the dimension of the containing
+    tensor acting as the channel. The size of the tensor along this dimension
+    is expected to match the number of provided `scale`-`zeroPoint` pairs, and
+    a given pair *i* applies to all elements in the tensor whose index along
+    dimension `channelAxis` is *i*. A quantized data type using per-channel
+    quantization is always expected to be contained within a tensor type.
+
+    Here are some examples:
+
+    ```
+    // A 2x3x4 tensor contains 8-bit signed integers representing 32-bit
+    // floats. Dimension 1 of the tensor acts as the channel dimension. Its
+    // size 3 matches the number of provided scale values. Tensor elements at
+    // positions [*][0][*], [*][1][*], and [*][2][*] use scales 3.0, 4.0, and
+    // 5.0, respectively.
+    tensor<2x3x4x!quant.uniform<i8:f32:1, {3.0, 4.0, 5.0}>>
+
+    // A 2D dynamically sized tensor contains 16-bit unsigned integers
+    // representing 32-bit floats. Dimension 0 of the tensor acts as the
+    // channel dimension. Since 2 scale and zero-point values are provided, the
+    // size of dimension 0 is expected to be 2 at runtime. Tensor elements
+    // [0][*] use scale 2.0 and zero point 10, while elements [1][*] use scale
+    // 3.0 and zero point 20.
+    tensor<?x?x!quant.uniform<u16:f32:0, {2.0:10, 3.0:20}>>
+    ```
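+
+    As an illustration of how a pair is selected, in the first example above,
+    the element at position [0][2][3] has index 2 along dimension 1 and
+    therefore uses the third scale, 5.0. With the default zero point of 0, a
+    stored value of -4 at that position represents the expressed value
+
+    $$
+    (-4 ~-~ 0) ~\times~ 5.0 = -20.0
+    $$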
+
+    ## Per-axis quantization integrity
+
     When type `!quant.uniform` contains per-axis quantization information, the
@@ -38,8 +180,8 @@ def Quant_Dialect : Dialect {
     // scalar type.
     %result = quant.qcast %input : f32 to !quant.uniform<i8:f32:0, {1.0}>
 
-    // Correct. Type `!quant.uniform` with per-channel quantization is wrapped in
-    // a `tensor` type.
+    // Correct. Type `!quant.uniform` with per-channel quantization is wrapped
+    // in a `tensor` type.
     %result = quant.qcast %input : tensor<2xf32> to tensor<2x!quant.uniform<i8:f32:0, {1.0, 2.0}>>
     ```
 
@@ -47,8 +189,8 @@ def Quant_Dialect : Dialect {
     must be greater than the channel axis specified in the quantized type.
 
     ```
-    // Incorrect. The tensor rank (2) is not greater than the channel axis in the
-    // quantized type (3).
+    // Incorrect. The tensor rank (2) is not greater than the channel axis in
+    // the quantized type (3).
     %result = quant.qcast %input : tensor<1x2xf32> to tensor<1x2x!quant.uniform<i8:f32:3, {1.0, 2.0}>>
 
     // Correct. The tensor rank (2) is now greater than the channel axis (1):
@@ -64,8 +206,8 @@ def Quant_Dialect : Dialect {
     // quantized type.
     %result = quant.qcast %input : tensor<?x3xf32> to tensor<?x3x!quant.uniform<i8:f32:1, {2.0, 3.0}>>
 
-    // Correct. The quantized type now includes 3 scale values, matching the size
-    // of dimension 1 of the result tensor.
+    // Correct. The quantized type now includes 3 scale values, matching the
+    // size of dimension 1 of the result tensor.
     %result = quant.qcast %input : tensor<?x3xf32> to tensor<?x3x!quant.uniform<i8:f32:1, {2.0, 3.0, 4.0}>>
     ```
   }];

From a48137b51afc2df1da245f5448d91c4be91ad1a2 Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena
Date: Tue, 30 Jul 2024 15:22:52 -0400
Subject: [PATCH 20/22] Added link-time dependencies

---
 mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt
index 662c3e368b624..2fd4a41999d45 100644
--- a/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt
@@ -9,8 +9,18 @@ add_mlir_dialect_library(MLIRQuantTransforms
   MLIRQuantTransformsIncGen
 
   LINK_LIBS PUBLIC
-  MLIRQuantDialect
+  MLIRArithDialect
+  MLIRFuncDialect
+  MLIRFuncTransforms
+  MLIRIndexDialect
+  MLIRIR
+  MLIRLinalgDialect
+  MLIRLinalgUtils
   MLIRPass
+  MLIRQuantDialect
+  MLIRShapeDialect
+  MLIRTensorDialect
   MLIRTransforms
   MLIRTransformUtils
+
   )

From 320d5a24c9be4a76fb1444d4f735427f2b928cee Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena
Date: Tue, 27 Aug 2024 14:44:28 -0400
Subject: [PATCH 21/22] Addressing review feedback

---
 mlir/include/mlir/Dialect/Quant/IR/QuantBase.td |  6 +++---
 mlir/include/mlir/Dialect/Quant/IR/QuantOps.td  | 14 ++++++++++++++
 2 files changed, 17 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
index 0a014b564b058..791cb9de48d05 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
@@ -6,7 +6,7 @@
 //
-// Predicates for types in the Quantization dialect.
+// Quantization dialect, types, and traits.
 //
 //===----------------------------------------------------------------------===//
 
@@ -76,7 +76,7 @@ def Quant_Dialect : Dialect {
       `0...255` for `u8`).
 
     - `expressedType`: Floating-point type of the value expressed by this
-      quantized type.
+      quantized type (e.g., `f32`, `f80`, `bf16`, or `tf32`).
 
     - `scale`: Floating-point value of type `expressedType` used in the
       conversion between stored and expressed values.
@@ -217,7 +217,7 @@ def Quant_Dialect : Dialect {
 
 //===----------------------------------------------------------------------===//
-// Type definitions
+// Type predicates
 //===----------------------------------------------------------------------===//
 
 class quant_ScalarOrTensorOf<Type etype> :
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
index 52dfc6b051de7..6ef925146dce6 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
@@ -50,6 +50,13 @@ def quant_DequantizeCastOp : quant_Op<"dcast", [
     per-channel quantization, the appropriate `scale` and `zeroPoint` values
     are used for each tensor element computation according to the channel the
     element belongs to.
+
+    The numerical results produced by the algorithm above may vary depending on
+    the rounding methods used by `convertIntToFloat()`, subtraction (`-`), and
+    multiplication (`*`). This operation does not define specific rounding
+    methods; instead, it is the responsibility of a transform pipeline to
+    determine which rounding method to apply when this operation is broken down
+    into lower-level dialects.
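+
+    For illustration only, a pipeline might expand a scalar `dcast` of type
+    `!quant.uniform<i8:f32, 2.0:5>` into `arith` operations as in the sketch
+    below. The value names are hypothetical, and an actual pass may widen the
+    storage type first or make different rounding choices:
+
+    ```
+    // %stored holds the i8 storage value (e.g., produced by quant.scast).
+    // Compute expressed = (stored - zeroPoint) * scale.
+    %zp = arith.constant 5 : i8
+    %diff = arith.subi %stored, %zp : i8
+    %fp = arith.sitofp %diff : i8 to f32
+    %scale = arith.constant 2.0 : f32
+    %expressed = arith.mulf %fp, %scale : f32
+    ```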
 
     The operation must satisfy the following syntactic constraints:
 
@@ -120,6 +127,13 @@ def quant_QuantizeCastOp : quant_Op<"qcast", [
     `zeroPoint` values are used for each tensor element computation according
     to the channel the element belongs to.
 
+    The numerical results produced by the algorithm above may vary depending on
+    the rounding methods used by `convertIntToFloat()`, `convertFloatToInt()`,
+    `clamp()`, division (`/`), and addition (`+`). This operation does not
+    define specific rounding methods; instead, it is the responsibility of a
+    transform pipeline to determine which rounding method to apply when this
+    operation is broken down into lower-level dialects.
+
     The operation must satisfy the following syntactic constraints:
 
     - Operand `input` must be a floating-point scalar or tensor.

From 9fe55fb0a3a10f8ec1bbdd2027a7b540927ab487 Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena
Date: Thu, 26 Sep 2024 13:43:20 -0400
Subject: [PATCH 22/22] Removed unnecessary includes. Addressed feedback
 regarding additional type integrity checks.

---
 mlir/lib/Dialect/Quant/IR/QuantOps.cpp        |  6 ----
 mlir/lib/Dialect/Quant/IR/QuantTypes.cpp      | 30 +++++++++++++++++++
 .../Dialect/Quant/parse-uniform-invalid.mlir  | 25 ++++++++++++++++
 3 files changed, 55 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
index 6a709488ce01c..c584903f3a15d 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
@@ -12,14 +12,8 @@
 #include "mlir/Dialect/Quant/IR/Quant.h"
 #include "mlir/Dialect/Quant/IR/QuantTypes.h"
 #include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
-#include "llvm/ADT/StringRef.h"
-#include "llvm/ADT/Twine.h"
-#include "llvm/Support/MathExtras.h"
-#include <numeric>
 
 #include "mlir/Dialect/Quant/IR/QuantOpsDialect.cpp.inc"
 
diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index 9ef774ee704ed..ac01b37a55307 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -20,6 +20,22 @@ using namespace mlir;
 using namespace mlir::quant;
 using namespace mlir::quant::detail;
 
+namespace {
+
+// Return the minimum scale representable in a given float type
+double getMinScale(Type expressedType) {
+  auto floatType = cast<FloatType>(expressedType);
+  return APFloat::getSmallest(floatType.getFloatSemantics()).convertToDouble();
+}
+
+// Return the maximum scale representable in a given float type
+double getMaxScale(Type expressedType) {
+  auto floatType = cast<FloatType>(expressedType);
+  return APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble();
+}
+
+} // namespace
+
 unsigned QuantizedType::getFlags() const {
   return static_cast<ImplType *>(impl)->flags;
 }
 
@@ -304,8 +320,13 @@ LogicalResult UniformQuantizedType::verifyInvariants(
     return emitError() << "expressed type must be floating point";
 
   // Verify scale.
+  double minScale = getMinScale(expressedType);
+  double maxScale = getMaxScale(expressedType);
   if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale))
     return emitError() << "illegal scale: " << scale;
+  if (scale < minScale || scale > maxScale)
+    return emitError() << "scale out of expressed type range [" << minScale
+                       << ", " << maxScale << "]";
 
   return success();
 }
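With these bounds in place, for an `f16` expressed type the accepted scale
range runs from the smallest positive subnormal to the largest finite value
of the format:

$$
2^{-24} \approx 5.96 \times 10^{-8}
\qquad\textrm{and}\qquad
(2 - 2^{-10}) \times 2^{15} = 65504
$$

Scales outside this range are rejected by the checks above, which is what the
new tests below exercise for both the per-layer and per-axis verifiers.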
@@ -364,11 +385,20 @@ LogicalResult UniformQuantizedPerAxisType::verifyInvariants(
            << scales.size() << ", " << zeroPoints.size();
 
   // Verify scale.
+  double minScale = getMinScale(expressedType);
+  double maxScale = getMaxScale(expressedType);
   for (double scale : scales) {
     if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale))
       return emitError() << "illegal scale: " << scale;
+    if (scale < minScale || scale > maxScale)
+      return emitError() << "scale out of expressed type range [" << minScale
+                         << ", " << maxScale << "]";
   }
 
+  // Verify quantized dimension.
+  if (quantizedDimension < 0)
+    return emitError() << "illegal quantized dimension: " << quantizedDimension;
+
   return success();
 }
 
diff --git a/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir b/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir
index a82e8efdb1a3c..7613a344cf2b8 100644
--- a/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir
+++ b/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir
@@ -120,3 +120,28 @@
 // provided.
 // expected-error@+1 {{expected floating point literal}}
 !qalias = !quant.uniform<u8<1:16>:f32, {2.000000e+02,-19.987200e-01:1}>
+
+// -----
+// Illegal negative axis in per-axis quantization
+// expected-error@+1 {{illegal quantized dimension: -1}}
+!qalias = !quant.uniform<i8:f32:-1, {2.0:10, 3.0:20}>
+
+// -----
+// Scale f16 underflow
+// expected-error@+1 {{scale out of expressed type range}}
+!qalias = !quant.uniform<i8:f16, 1.0e-10>
+
+// -----
+// Scale f16 overflow
+// expected-error@+1 {{scale out of expressed type range}}
+!qalias = !quant.uniform<i8:f16, 1.0e+10>
+
+// -----
+// Scale f16 underflow in per-axis quantization
+// expected-error@+1 {{scale out of expressed type range}}
+!qalias = !quant.uniform<i8:f16:0, {1.0e-10}>
+
+// -----
+// Scale f16 overflow in per-axis quantization
+// expected-error@+1 {{scale out of expressed type range}}
+!qalias = !quant.uniform<i8:f16:0, {1.0e+10}>
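For contrast, a scale that falls within the representable `f16` range
verifies cleanly; for instance, this hypothetical alias would parse without
error:

```
!qalias = !quant.uniform<i8:f16, 1.0>
```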