[mlir] Improvements to the 'quant' dialect #100667

Merged: 27 commits into llvm:main from quant-dialect, Sep 26, 2024

Conversation

rafaelubalmw
Contributor

Full revamp of the 'quant' dialect. This is an implementation of the RFC at https://discourse.llvm.org/t/rfc-improvements-in-the-quant-dialect/79942
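
As a quick orientation (not part of the patch itself), here is a minimal C++ sketch of constructing a uniform quantized type through the dialect's API, assuming the post-PR header layout; the helper name is hypothetical:

```cpp
// Minimal sketch, assuming the post-PR header locations under
// mlir/Dialect/Quant/IR/. 'makeI8QuantType' is a hypothetical helper.
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

// Builds !quant.uniform<i8:f32, 5.000000e-01:-10>: i8 storage type, f32
// expressed type, scale 0.5, zero point -10, natural i8 storage bounds.
quant::UniformQuantizedType makeI8QuantType(MLIRContext &ctx) {
  Builder b(&ctx);
  return quant::UniformQuantizedType::get(
      quant::QuantizationFlags::Signed, b.getIntegerType(8), b.getF32Type(),
      /*scale=*/0.5, /*zeroPoint=*/-10, /*storageTypeMin=*/-128,
      /*storageTypeMax=*/127);
}
```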

@rafaelubalmw rafaelubalmw requested a review from sabauma July 25, 2024 22:47

github-actions bot commented Jul 25, 2024

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff 439dcfafc5af3e018a80e8112bc515249e1cbfbc 9fe55fb0a3a10f8ec1bbdd2027a7b540927ab487 --extensions cpp,h -- mlir/include/mlir/Dialect/Quant/Transforms/Passes.h mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h mlir/include/mlir/InitAllDialects.h mlir/include/mlir/InitAllPasses.h mlir/lib/CAPI/Dialect/Quant.cpp mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.h mlir/lib/Dialect/Quant/IR/QuantOps.cpp mlir/lib/Dialect/Quant/IR/QuantTypes.cpp mlir/lib/Dialect/Quant/IR/TypeParser.cpp mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp mlir/lib/Dialect/Tosa/IR/TosaOps.cpp mlir/include/mlir/Dialect/Quant/IR/Quant.h mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h mlir/include/mlir/Dialect/Quant/Utils/FakeQuantSupport.h mlir/include/mlir/Dialect/Quant/Utils/UniformSupport.h
View the diff from clang-format here.
diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
index c584903f3a..adb737e6e6 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
@@ -17,7 +17,6 @@
 
 #include "mlir/Dialect/Quant/IR/QuantOpsDialect.cpp.inc"
 
-
 namespace mlir {
 namespace quant {
 
@@ -37,7 +36,8 @@ namespace {
 LogicalResult verifyPerAxisQuantization(Operation *op,
                                         QuantizedType quantizedType,
                                         Type containerType) {
-  auto quantizedPerAxisType = dyn_cast<UniformQuantizedPerAxisType>(quantizedType);
+  auto quantizedPerAxisType =
+      dyn_cast<UniformQuantizedPerAxisType>(quantizedType);
   if (!quantizedPerAxisType)
     return success();
 
@@ -54,7 +54,8 @@ LogicalResult verifyPerAxisQuantization(Operation *op,
 
   int64_t quantizedDimensionSize = tensorType.getDimSize(quantizedDimension);
   if (quantizedDimensionSize != ShapedType::kDynamic &&
-      quantizedDimensionSize != (int64_t)quantizedPerAxisType.getScales().size())
+      quantizedDimensionSize !=
+          (int64_t)quantizedPerAxisType.getScales().size())
     return op->emitError(
         "quantized dimension size does not match number of scales");
 
@@ -84,8 +85,7 @@ LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType,
   return verifyPerAxisQuantization(op, quantizedType, containerType);
 }
 
-}  // namespace
-
+} // namespace
 
 //===----------------------------------------------------------------------===//
 // Dialect
@@ -101,7 +101,6 @@ void QuantDialect::initialize() {
   detail::addBytecodeInterface(this);
 }
 
-
 //===----------------------------------------------------------------------===//
 // DequantizeCastOp
 //===----------------------------------------------------------------------===//
@@ -130,7 +129,6 @@ QuantizedType DequantizeCastOp::getQuantizedType() {
   return cast<QuantizedType>(getElementTypeOrSelf(getInput().getType()));
 }
 
-
 //===----------------------------------------------------------------------===//
 // QuantizeCastOp
 //===----------------------------------------------------------------------===//
@@ -160,7 +158,6 @@ QuantizedType QuantizeCastOp::getQuantizedType() {
   return cast<QuantizedType>(getElementTypeOrSelf(getResult().getType()));
 }
 
-
 //===----------------------------------------------------------------------===//
 // StorageCastOp
 //===----------------------------------------------------------------------===//
@@ -205,10 +202,8 @@ QuantizedType StorageCastOp::getQuantizedType() {
   return cast<QuantizedType>(resultScalarType);
 }
 
-
 } // namespace quant
 } // namespace mlir
 
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
-
diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index ac01b37a55..a4b6fec8be 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -6,9 +6,9 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
 #include "TypeDetail.h"
 #include "mlir/Dialect/Quant/IR/Quant.h"
-#include "mlir/Dialect/Quant/IR/QuantTypes.h"
 
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/MLIRContext.h"
@@ -34,7 +34,7 @@ double getMaxScale(Type expressedType) {
   return APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble();
 }
 
-}  // namespace
+} // namespace
 
 unsigned QuantizedType::getFlags() const {
   return static_cast<ImplType *>(impl)->flags;
diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
index 4adeb9218f..6929d8861e 100644
--- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
@@ -38,11 +38,11 @@ Type getScalarType(Type inputType) {
   return inputType;
 }
 
-// Return the shape of an input value as a list of attributes (static dimensions)
-// and values (dynamic dimensions). If 'input' is a scalar, an empty list is
-// returned. If 'input' is a tensor, its shape is returned.
-SmallVector<OpFoldResult>
-getScalarOrTensorShape(OpBuilder &builder, Location loc, Value input) {
+// Return the shape of an input value as a list of attributes (static
+// dimensions) and values (dynamic dimensions). If 'input' is a scalar, an empty
+// list is returned. If 'input' is a tensor, its shape is returned.
+SmallVector<OpFoldResult> getScalarOrTensorShape(OpBuilder &builder,
+                                                 Location loc, Value input) {
   if (isa<TensorType>(input.getType()))
     return tensor::getMixedSizes(builder, loc, input);
   return {};
@@ -100,16 +100,16 @@ std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
 
   // Turn input size into 1D tensor
   auto flatShapeType = shape::getExtentTensorType(context, 1);
-  auto flatInputShape = builder.create<tensor::FromElementsOp>(
-      loc, flatShapeType, inputSize);
+  auto flatInputShape =
+      builder.create<tensor::FromElementsOp>(loc, flatShapeType, inputSize);
 
   // Reshape input tensor into 1D
   auto inputType = cast<UnrankedTensorType>(input.getType());
   auto elementType = inputType.getElementType();
   auto flatInputType =
       RankedTensorType::get({ShapedType::kDynamic}, elementType);
-  auto flatInput = builder.create<tensor::ReshapeOp>(
-      loc, flatInputType, input, flatInputShape);
+  auto flatInput = builder.create<tensor::ReshapeOp>(loc, flatInputType, input,
+                                                     flatInputShape);
   return std::make_pair(flatInput, inputShape);
 }
 
@@ -135,11 +135,9 @@ std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
 // - inputShape
 //   1D extent tensor containing the shape of the original unranked input.
 //
-std::pair<Value, Value> flattenUnrankedTensorAroundAxis(OpBuilder &builder,
-                                                        Location loc,
-                                                        Value input,
-                                                        int64_t axis,
-                                                        int64_t axisSize) {
+std::pair<Value, Value>
+flattenUnrankedTensorAroundAxis(OpBuilder &builder, Location loc, Value input,
+                                int64_t axis, int64_t axisSize) {
   // Get full tensor shape
   auto *context = builder.getContext();
   auto indexType = builder.getIndexType();
@@ -149,16 +147,20 @@ std::pair<Value, Value> flattenUnrankedTensorAroundAxis(OpBuilder &builder,
   // Get shape and sizes on left and right of axis
   auto axisValue = builder.create<arith::ConstantIndexOp>(loc, axis);
   auto axisNextValue = builder.create<arith::ConstantIndexOp>(loc, axis + 1);
-  auto shapeLeft = builder.create<shape::SplitAtOp>(
-      loc, TypeRange{shapeType, shapeType}, inputShape, axisValue)
-      .getResult(0);
-  auto sizeLeft = builder.create<shape::NumElementsOp>(
-      loc, indexType, shapeLeft);
-  auto shapeRight = builder.create<shape::SplitAtOp>(
-      loc, TypeRange{shapeType, shapeType}, inputShape, axisNextValue)
-      .getResult(1);
-  auto sizeRight = builder.create<shape::NumElementsOp>(
-      loc, indexType, shapeRight);
+  auto shapeLeft =
+      builder
+          .create<shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType},
+                                    inputShape, axisValue)
+          .getResult(0);
+  auto sizeLeft =
+      builder.create<shape::NumElementsOp>(loc, indexType, shapeLeft);
+  auto shapeRight =
+      builder
+          .create<shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType},
+                                    inputShape, axisNextValue)
+          .getResult(1);
+  auto sizeRight =
+      builder.create<shape::NumElementsOp>(loc, indexType, shapeRight);
 
   // Compute flat input shape as a 3-element 1D tensor
   auto axisSizeValue = builder.create<arith::ConstantIndexOp>(loc, axisSize);
@@ -171,8 +173,8 @@ std::pair<Value, Value> flattenUnrankedTensorAroundAxis(OpBuilder &builder,
   auto elementType = inputType.getElementType();
   auto flatInputType = RankedTensorType::get(
       {ShapedType::kDynamic, axisSize, ShapedType::kDynamic}, elementType);
-  auto flatInput = builder.create<tensor::ReshapeOp>(
-      loc, flatInputType, input, flatInputShape);
+  auto flatInput = builder.create<tensor::ReshapeOp>(loc, flatInputType, input,
+                                                     flatInputShape);
 
   return std::make_pair(flatInput, inputShape);
 }
@@ -190,7 +192,8 @@ Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input,
   auto inputType = cast<RankedTensorType>(input.getType());
   auto elementType = inputType.getElementType();
   auto unrankedType = UnrankedTensorType::get(elementType);
-  return builder.create<tensor::ReshapeOp>(loc, unrankedType, input, inputShape);
+  return builder.create<tensor::ReshapeOp>(loc, unrankedType, input,
+                                           inputShape);
 }
 
 // Create a tensor constant containing all scales in a per-channel quantized
@@ -209,7 +212,8 @@ Value materializePerChannelScales(OpBuilder &builder, Location loc,
   auto scaleAttrs = llvm::map_to_vector(scales, [&](double scale) -> Attribute {
     return builder.getFloatAttr(expressedType, scale);
   });
-  auto tensorType = RankedTensorType::get({(int64_t) scales.size()}, expressedType);
+  auto tensorType =
+      RankedTensorType::get({(int64_t)scales.size()}, expressedType);
   auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs);
   return builder.create<arith::ConstantOp>(loc, tensorType, scalesAttr);
 }
@@ -228,9 +232,8 @@ Value materializePerChannelZeroPoints(
     UniformQuantizedPerAxisType quantizedType) {
   auto zeroPoints = quantizedType.getZeroPoints();
   auto storageType = quantizedType.getStorageType();
-  auto zeroPointAttrs = llvm::map_to_vector(
-      zeroPoints,
-      [&](int64_t zeroPoint) -> Attribute {
+  auto zeroPointAttrs =
+      llvm::map_to_vector(zeroPoints, [&](int64_t zeroPoint) -> Attribute {
         return builder.getIntegerAttr(storageType, zeroPoint);
       });
   auto tensorType =
@@ -299,7 +302,7 @@ Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input,
   return builder.create<arith::UIToFPOp>(loc, resultType, input);
 }
 
-// Quantize a scalar or ranked tensor value. The stored value is clamped using 
+// Quantize a scalar or ranked tensor value. The stored value is clamped using
 // the storage bounds encoded in the given quantized type.
 //
 // See function 'convertRanked()' below for a description of the arguments.
@@ -308,8 +311,7 @@ Value quantizeValue(OpBuilder &builder, Location loc, Value input,
                     Value zeroPoint, QuantizedType quantizedType) {
   // Convert scale to tensor if necessary
   auto inputType = input.getType();
-  scale = getScalarOrTensorConstant(
-      builder, loc, scale, inputType, inputShape);
+  scale = getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape);
 
   // Scale input
   auto scaledValue = builder.create<arith::DivFOp>(loc, input, scale);
@@ -322,8 +324,7 @@ Value quantizeValue(OpBuilder &builder, Location loc, Value input,
                                           inputShape);
 
     // Convert zero point from storage to expressed type
-    zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint,
-                                      scale.getType(),
+    zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(),
                                       quantizedType.isSigned());
 
     // Add zero point to stored value
@@ -334,9 +335,9 @@ Value quantizeValue(OpBuilder &builder, Location loc, Value input,
   // Convert stored value to storage type
   auto storageScalarOrTensorType =
       getScalarOrTensorType(quantizedType.getStorageType(), inputType);
-  auto storedValueInt = convertFloatToInteger(
-      builder, loc, storedValueFloat, storageScalarOrTensorType,
-      quantizedType.isSigned());
+  auto storedValueInt = convertFloatToInteger(builder, loc, storedValueFloat,
+                                              storageScalarOrTensorType,
+                                              quantizedType.isSigned());
 
   // Clamp stored value it if the storage type is bound
   auto storedValueClamped = clampScalarOrTensor(builder, loc, storedValueInt,
@@ -352,12 +353,11 @@ Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
                       Value zeroPoint, QuantizedType quantizedType) {
   // Convert scale to tensor if necessary
   auto inputType = input.getType();
-  scale = getScalarOrTensorConstant(
-      builder, loc, scale, inputType, inputShape);
+  scale = getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape);
 
   // Convert stored value to float
-  auto result = convertIntegerToFloat(
-      builder, loc, input, scale.getType(), quantizedType.isSigned());
+  auto result = convertIntegerToFloat(builder, loc, input, scale.getType(),
+                                      quantizedType.isSigned());
 
   // Skip unnecessary computations if no zero point is given
   if (!matchPattern(zeroPoint, m_Zero())) {
@@ -366,8 +366,7 @@ Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
                                           inputShape);
 
     // Convert zero point from storage to expressed type
-    zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint,
-                                      scale.getType(),
+    zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(),
                                       quantizedType.isSigned());
 
     // Subtract zero point to stored value
@@ -501,35 +500,33 @@ Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
   auto initShape = tensor::getMixedSizes(builder, loc, input);
   Value init = builder.create<tensor::EmptyOp>(loc, initShape, elementType);
 
-  SmallVector<utils::IteratorType> iteratorTypes(
-      inputRank, utils::IteratorType::parallel);
+  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
+                                                 utils::IteratorType::parallel);
   auto channelAxisAffineMap = AffineMap::get(
       inputRank, 0, builder.getAffineDimExpr(channelAxis), context);
   SmallVector<AffineMap> indexingMaps{
-    builder.getMultiDimIdentityMap(inputRank),
-    channelAxisAffineMap,
-    channelAxisAffineMap,
-    builder.getMultiDimIdentityMap(inputRank)
-  };
-  auto result = builder.create<linalg::GenericOp>(
-      loc,
-      init.getType(),  // resultType
-      ValueRange{input, scales, zeroPoints},  // inputs
-      ValueRange{init},  // outputs
-      indexingMaps,
-      iteratorTypes,
-      [&](OpBuilder& builder, Location loc, ValueRange args) {
-        assert(args.size() == 4);
-        auto input = args[0];
-        auto scale = args[1];
-        auto zeroPoint = args[2];
-
-        auto result = convertRanked(builder, loc, op, input, {}, scale,
-                                    zeroPoint, quantizedType);
-
-        builder.create<linalg::YieldOp>(loc, result);
-      })
-      .getResult(0);
+      builder.getMultiDimIdentityMap(inputRank), channelAxisAffineMap,
+      channelAxisAffineMap, builder.getMultiDimIdentityMap(inputRank)};
+  auto result = builder
+                    .create<linalg::GenericOp>(
+                        loc,
+                        init.getType(),                        // resultType
+                        ValueRange{input, scales, zeroPoints}, // inputs
+                        ValueRange{init},                      // outputs
+                        indexingMaps, iteratorTypes,
+                        [&](OpBuilder &builder, Location loc, ValueRange args) {
+                          assert(args.size() == 4);
+                          auto input = args[0];
+                          auto scale = args[1];
+                          auto zeroPoint = args[2];
+
+                          auto result =
+                              convertRanked(builder, loc, op, input, {}, scale,
+                                            zeroPoint, quantizedType);
+
+                          builder.create<linalg::YieldOp>(loc, result);
+                        })
+                    .getResult(0);
 
   return result;
 }
@@ -551,7 +548,7 @@ Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op,
   // Flatten unranked tensor into a 3D ranked tensor if necessary
   bool isUnranked = isa<UnrankedTensorType>(input.getType());
   int64_t channelAxis = quantizedType.getQuantizedDimension();
-  int64_t channelAxisSize = (int64_t) quantizedType.getScales().size();
+  int64_t channelAxisSize = (int64_t)quantizedType.getScales().size();
   Value inputShape;
   if (isUnranked) {
     std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis(
@@ -597,7 +594,8 @@ Value convertQuantized(OpBuilder &builder, Location loc, Operation *op,
 }
 
 // Lowering pattern for 'quant.dcast'
-struct DequantizeCastOpConversion : public OpConversionPattern<quant::DequantizeCastOp> {
+struct DequantizeCastOpConversion
+    : public OpConversionPattern<quant::DequantizeCastOp> {
   using OpConversionPattern<quant::DequantizeCastOp>::OpConversionPattern;
 
   LogicalResult
@@ -622,7 +620,8 @@ struct DequantizeCastOpConversion : public OpConversionPattern<quant::Dequantize
 };
 
 // Lowering pattern for 'quant.qcast'
-struct QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastOp> {
+struct QuantizeCastOpConversion
+    : public OpConversionPattern<quant::QuantizeCastOp> {
   using OpConversionPattern<quant::QuantizeCastOp>::OpConversionPattern;
 
   LogicalResult
@@ -650,12 +649,8 @@ struct LowerQuantOps : public impl::LowerQuantOpsBase<LowerQuantOps> {
     ConversionTarget target(getContext());
     target.addLegalOp<quant::StorageCastOp>();
     target.addIllegalDialect<quant::QuantDialect>();
-    target.addLegalDialect<
-      arith::ArithDialect,
-      linalg::LinalgDialect,
-      shape::ShapeDialect,
-      tensor::TensorDialect
-    >();
+    target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
+                           shape::ShapeDialect, tensor::TensorDialect>();
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
@@ -666,10 +661,8 @@ struct LowerQuantOps : public impl::LowerQuantOpsBase<LowerQuantOps> {
 } // namespace
 
 void populateLowerQuantOpsPatterns(RewritePatternSet &patterns) {
-  patterns.add<
-    DequantizeCastOpConversion,
-    QuantizeCastOpConversion
-  >(patterns.getContext());
+  patterns.add<DequantizeCastOpConversion, QuantizeCastOpConversion>(
+      patterns.getContext());
 }
 
 } // namespace quant
diff --git a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
index 8996eff61a..6191272266 100644
--- a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
@@ -36,9 +36,10 @@ class QuantizedTypeConverter : public TypeConverter {
   static Type convertQuantizedType(QuantizedType quantizedType) {
     return quantizedType.getStorageType();
   }
-  
+
   static Type convertTensorType(TensorType tensorType) {
-    if (auto quantizedType = dyn_cast<QuantizedType>(tensorType.getElementType()))
+    if (auto quantizedType =
+            dyn_cast<QuantizedType>(tensorType.getElementType()))
       return tensorType.clone(convertQuantizedType(quantizedType));
     return tensorType;
   }
@@ -50,7 +51,6 @@ class QuantizedTypeConverter : public TypeConverter {
   }
 
 public:
-
   explicit QuantizedTypeConverter() {
     addConversion([](Type type) { return type; });
     addConversion(convertQuantizedType);
@@ -63,7 +63,8 @@ public:
 };
 
 // Conversion pass
-class StripFuncQuantTypes : public impl::StripFuncQuantTypesBase<StripFuncQuantTypes> {
+class StripFuncQuantTypes
+    : public impl::StripFuncQuantTypesBase<StripFuncQuantTypes> {
 
   // Return whether a type is considered legal when occurring in the header of
   // a function or as an operand to a 'return' op.
@@ -74,11 +75,10 @@ class StripFuncQuantTypes : public impl::StripFuncQuantTypesBase<StripFuncQuantT
   }
 
 public:
-
   void runOnOperation() override {
-    
+
     auto moduleOp = cast<ModuleOp>(getOperation());
-    auto* context = &getContext();
+    auto *context = &getContext();
 
     QuantizedTypeConverter typeConverter;
     ConversionTarget target(*context);
@@ -111,4 +111,3 @@ public:
 
 } // namespace quant
 } // namespace mlir
-
diff --git a/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp b/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp
index fb27640bfd..308ff35e01 100644
--- a/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp
+++ b/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Quant/IR/QuantTypes.h"
 #include "mlir/Dialect/Quant/Utils/FakeQuantSupport.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
 
 using namespace mlir;
 using namespace mlir::quant;
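
For readers skimming the diff, here is a hedged scalar model of the arithmetic that `quantizeValue` and `dequantizeValue` (reflowed above) materialize; the function names are illustrative, and the pass actually emits `arith`/`linalg` ops over scalars and tensors rather than plain C++:

```cpp
// Hedged scalar model of the quantize/dequantize lowering; function names
// are illustrative, not taken from the patch.
#include <algorithm>
#include <cstdint>

int8_t quantizeScalar(float expressed, float scale, int8_t zeroPoint) {
  // Scale the value and add the zero point in the expressed (float) domain.
  float storedF = expressed / scale + static_cast<float>(zeroPoint);
  // arith.fptosi truncates toward zero when converting to the storage type.
  int32_t stored = static_cast<int32_t>(storedF);
  // Clamp to the storage bounds encoded in the quantized type (i8 here).
  return static_cast<int8_t>(std::clamp(stored, -128, 127));
}

float dequantizeScalar(int8_t stored, float scale, int8_t zeroPoint) {
  // Subtract the zero point, then rescale back to the expressed type.
  return (static_cast<float>(stored) - static_cast<float>(zeroPoint)) * scale;
}
```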

@rafaelubalmw
Contributor Author

@jpienaar For sure, Jacques. I'll stay tuned for any downstream issues this change may cause and will be happy to address them. The most likely impact relates to the new location of the header files and the renamed global symbols, both of which I modified for consistency with other dialects, so downstream #includes may be affected. Other than that, I've retained the intended semantics of the original dialect types and ops, according to my best interpretation.
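
As an illustration of that include-path impact, a downstream fix-up would look roughly like this (the pre-PR path shown is an assumption and may vary by project):

```cpp
// Before this PR (old header location, shown as a comment; assumed path):
//   #include "mlir/Dialect/Quant/QuantTypes.h"
// After this PR (new location under an IR/ subdirectory, as seen in the diff):
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
```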

@rafaelubalmw rafaelubalmw merged commit 852b648 into llvm:main Sep 26, 2024
5 of 7 checks passed
@rafaelubalmw rafaelubalmw deleted the quant-dialect branch September 26, 2024 18:10
norx1991 added a commit to norx1991/llvm-project that referenced this pull request Sep 26, 2024
@norx1991 norx1991 mentioned this pull request Sep 26, 2024
rupprecht pushed a commit that referenced this pull request Sep 26, 2024
hanhanW added a commit to iree-org/llvm-project that referenced this pull request Sep 27, 2024
hanhanW added a commit to iree-org/llvm-project that referenced this pull request Sep 27, 2024
Sterling-Augustine pushed a commit to Sterling-Augustine/llvm-project that referenced this pull request Sep 27, 2024
Sterling-Augustine pushed a commit to Sterling-Augustine/llvm-project that referenced this pull request Sep 27, 2024
bjacob pushed a commit to iree-org/llvm-project that referenced this pull request Sep 30, 2024
puja2196 pushed a commit to puja2196/LLVM-tutorial that referenced this pull request Sep 30, 2024
bjacob added a commit to iree-org/iree that referenced this pull request Sep 30, 2024
Cherry-picks:
1. Cherry-picking llvm/llvm-project#110518

Carrying two local reverts:
1. Revert llvm/llvm-project#100667
   - As noted by @hanhanW on #18619, that PR "breaks the stablehlo build. We need to wait stablehlo bumping LLVM ahead of it and fix the issue. Then we can bump stablehlo and drop the local commit together."
2. Revert llvm/llvm-project#110170
   - That is just the Bazel change accompanying 1.

Signed-off-by: Benoit Jacob <[email protected]>
bjacob pushed a commit to iree-org/llvm-project that referenced this pull request Oct 1, 2024
bjacob pushed a commit to iree-org/llvm-project that referenced this pull request Oct 1, 2024
bjacob added a commit to iree-org/iree that referenced this pull request Oct 1, 2024
Existing reverts carried over:
1. Revert llvm/llvm-project#100667
   - As noted by @hanhanW on #18619, that PR "breaks the stablehlo build. We need to wait stablehlo bumping LLVM ahead of it and fix the issue. Then we can bump stablehlo and drop the local commit together."
2. Revert llvm/llvm-project#110170
   - That is just the Bazel change accompanying 1.

Signed-off-by: Benoit Jacob <[email protected]>
puja2196 pushed a commit to puja2196/LLVM-tutorial that referenced this pull request Oct 2, 2024
abhigunj pushed a commit to openxla/stablehlo that referenced this pull request Oct 3, 2024
The recent upstream [change](llvm/llvm-project#100667) has introduced quantization checks that are already present in the StableHLO core library. This commit removes these duplicate checks to avoid redundancy and potential inconsistencies.


| Checks proposed to be removed | StableHLO Code | Upstream MLIR |
|-|-|-|
| `channel-axis >= 0` | [cs](https://github.com/openxla/stablehlo/blob/1c0547f391dff5ac71d36dc20a916260afa78c61/stablehlo/dialect/Base.cpp#L795) | [cs](https://github.com/llvm/llvm-project/blob/96f37ae45310885e09195be09d9c05e1c1dff86b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp#L399) |
| scale within smallest and largest finite numbers determined by `expressed_type` | [cs](https://github.com/openxla/stablehlo/blob/1c0547f391dff5ac71d36dc20a916260afa78c61/stablehlo/dialect/Base.cpp#L765) | [cs1](https://github.com/llvm/llvm-project/blob/96f37ae45310885e09195be09d9c05e1c1dff86b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp#L327) [cs2](https://github.com/llvm/llvm-project/blob/96f37ae45310885e09195be09d9c05e1c1dff86b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp#L393C9-L393C45) |


Note that StableHLO has checks like `quantization_dimension < rank(self)` and `dim(self, quantization_dimension) = size(scales)` implemented at [cs](https://github.com/openxla/stablehlo/blob/1c0547f391dff5ac71d36dc20a916260afa78c61/stablehlo/dialect/Base.cpp#L795). In upstream MLIR, similar checks ([cs](https://github.com/llvm/llvm-project/blob/96f37ae45310885e09195be09d9c05e1c1dff86b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp#L51)) are encoded as part of the [dcast](https://github.com/llvm/llvm-project/blob/96f37ae45310885e09195be09d9c05e1c1dff86b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp#L110) and [qcast](https://github.com/llvm/llvm-project/blob/96f37ae45310885e09195be09d9c05e1c1dff86b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp#L139) ops and hence cannot be claimed as duplicates.

Related upstream clean-up: llvm/llvm-project#110604
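
For reference, the scale-range check cited in the table can be modeled as the sketch below, mirroring the `APFloat`-based bound visible in QuantTypes.cpp; the helper name is hypothetical:

```cpp
// Hedged sketch of the scale-range verification; 'isScaleInRange' is a
// hypothetical helper, not the actual verifier entry point.
#include "llvm/ADT/APFloat.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

bool isScaleInRange(FloatType expressedType, double scale) {
  const llvm::fltSemantics &sem = expressedType.getFloatSemantics();
  // Scale must lie between the smallest and largest finite positive values
  // representable in the expressed type.
  double minScale = llvm::APFloat::getSmallest(sem).convertToDouble();
  double maxScale = llvm::APFloat::getLargest(sem).convertToDouble();
  return scale >= minScale && scale <= maxScale;
}
```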