[mlir][linalg] Add quantized conv2d operator with FCHW,NCHW order #107740


Merged · 3 commits into llvm:main · Oct 19, 2024

Conversation


@ubfx ubfx commented Sep 8, 2024

This patch adds a quantized version of the linalg.conv2d_nchw_fchw Op. This is the "channel-first" ordering typically used by PyTorch and others.
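The op's accumulation, `(input - input_zero_point) * (kernel - kernel_zero_point)` summed over the `C`, `KH` and `KW` reduction dimensions, can be sketched as a plain-Python reference (illustrative only; the helper name, nested-list tensors, and shape handling are assumptions, not part of this patch or of any MLIR API):

```python
def conv2d_nchw_fchw_q(I, K, izp, kzp, strides=(1, 1), dilations=(1, 1)):
    """Quantized 2-D convolution, input layout NCHW, kernel layout FCHW.

    I:   input  as nested lists, shape [N][C][IH][IW]
    K:   kernel as nested lists, shape [F][C][KH][KW]
    izp, kzp: zero-point offsets, subtracted before the multiply to
    match the (x - izp) * (w - kzp) accumulation in the op body.
    """
    sh, sw = strides
    dh, dw = dilations
    n_, c_ = len(I), len(I[0])
    ih, iw = len(I[0][0]), len(I[0][0][0])
    f_, kh, kw = len(K), len(K[0][0]), len(K[0][0][0])
    # Output spatial extent for a given stride/dilation (no padding).
    oh = (ih - (kh - 1) * dh - 1) // sh + 1
    ow = (iw - (kw - 1) * dw - 1) // sw + 1
    O = [[[[0] * ow for _ in range(oh)] for _ in range(f_)] for _ in range(n_)]
    for n in range(n_):            # parallel dims: n, f, oh, ow
        for f in range(f_):
            for y in range(oh):
                for x in range(ow):
                    acc = 0
                    for c in range(c_):      # reduction dims: c, kh, kw
                        for ky in range(kh):
                            for kx in range(kw):
                                xi = I[n][c][y * sh + ky * dh][x * sw + kx * dw]
                                wk = K[f][c][ky][kx]
                                acc += (xi - izp) * (wk - kzp)
                    O[n][f][y][x] = acc
    return O
```

Note how the only difference from the existing `conv_2d_nhwc_fhwc_q` is the index order on `I` and `K`; the zero-point arithmetic is identical.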

@llvmbot

llvmbot commented Sep 8, 2024

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Felix Schneider (ubfx)

Changes



Full diff: https://github.com/llvm/llvm-project/pull/107740.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml (+137)
  • (modified) mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py (+28)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 46b3ec0f60ebfa..4648a9133953af 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -3114,6 +3114,143 @@ structured_op: !LinalgStructuredOpConfig
                     - !ScalarExpression
                       scalar_arg: KZp
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: conv_2d_nchw_fchw_q
+  cpp_class_name: Conv2DNchwFchwQOp
+  doc: |-
+    Performs 2-D convolution with zero point offsets.
+
+    Layout:
+      * Input: NCHW.
+      * Kernel: FCHW.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. This includes the zero
+    point offsets common to quantized operations.
+  implements:
+  - LinalgConvolutionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0,
+      s1, s2 * s3 + s4 * s5, s6 * s7 + s8 * s9)>
+  - !LinalgOperandDefConfig
+    name: K
+    kind: input_tensor
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s10,
+      s1, s4, s8)>
+  - !LinalgOperandDefConfig
+    name: IZp
+    kind: scalar
+    type_var: I32
+  - !LinalgOperandDefConfig
+    name: KZp
+    kind: scalar
+    type_var: I32
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0,
+      s10, s2, s6)>
+  - !LinalgOperandDefConfig
+    name: strides
+    kind: index_attr
+    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] ->
+      (s3, s7)>
+    default_indices:
+    - 1
+    - 1
+  - !LinalgOperandDefConfig
+    name: dilations
+    kind: index_attr
+    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] ->
+      (s5, s9)>
+    default_indices:
+    - 1
+    - 1
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10] -> (d0, d4, d2 * s3 + d5 * s5, d3 * s7 + d6 * s9)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10] -> (d1, d4, d5, d6)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10] -> ()>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10] -> ()>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10] -> (d0, d1, d2, d3)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  - reduction
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: O
+        - !ScalarExpression
+          scalar_fn:
+            kind: binary
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              scalar_fn:
+                kind: binary
+                fn_name: sub
+                operands:
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: I
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: IZp
+            - !ScalarExpression
+              scalar_fn:
+                kind: binary
+                fn_name: sub
+                operands:
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: K
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: KZp
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: conv_2d_nchw_fchw
   cpp_class_name: Conv2DNchwFchwOp
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 67bde8f736ef46..67bae10ad16ca2 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -875,6 +875,34 @@ def conv_2d_nhwc_fhwc_q(
         - TypeFn.cast_signed(U, IZp)
     ) * (TypeFn.cast_signed(U, K[D.f, D.kh, D.kw, D.c]) - TypeFn.cast_signed(U, KZp))
 
+@linalg_structured_op
+def conv_2d_nchw_fchw_q(
+    I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
+    K=TensorDef(T2, S.F, S.C, S.KH, S.KW),
+    IZp=ScalarDef(I32),
+    KZp=ScalarDef(I32),
+    O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs 2-D convolution with zero point offsets.
+
+    Layout:
+      * Input: NCHW.
+      * Kernel: FCHW.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. This includes the zero
+    point offsets common to quantized operations.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw)
+    O[D.n, D.f, D.oh, D.ow] += (
+        TypeFn.cast_signed(
+            U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]
+        )
+        - TypeFn.cast_signed(U, IZp)
+    ) * (TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw]) - TypeFn.cast_signed(U, KZp))
 
 @linalg_structured_op
 def conv_2d_nchw_fchw(


github-actions bot commented Sep 8, 2024

✅ With the latest revision this PR passed the Python code formatter.

@ubfx ubfx requested a review from rsuderman September 16, 2024 07:03

ubfx commented Sep 19, 2024

For context: in torch-mlir, we currently have to insert additional transpositions on the weights and inits for all quantized convolutions, because there is no quantized convolution op with a fitting layout:

https://github.com/llvm/torch-mlir/blob/5ce48dfacd971e5075786731bac2152ae855cab4/lib/Conversion/TorchToLinalg/Linear.cpp#L1165-L1167
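The workaround being replaced is a layout permutation around the NHWC-only quantized op; a minimal sketch of such an operand shuffle (hypothetical helper, nested-list tensors for illustration only) looks like:

```python
def nchw_to_nhwc(t):
    """Permute a nested-list tensor from [N][C][H][W] to [N][H][W][C].

    This is the kind of transposition torch-mlir previously had to
    materialize for the input, the weights, and the init tensor of every
    quantized convolution; with conv_2d_nchw_fchw_q it is unnecessary.
    """
    n, c = len(t), len(t[0])
    h, w = len(t[0][0]), len(t[0][0][0])
    return [
        [[[t[ni][ci][hi][wi] for ci in range(c)] for wi in range(w)]
         for hi in range(h)]
        for ni in range(n)
    ]
```

Three such permutations per convolution (plus one back on the result) are what the new channel-first op removes.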

@ubfx ubfx requested review from stellaraccident, ftynse and makslevental and removed request for stellaraccident, ftynse and makslevental September 19, 2024 13:40

@MaheshRavishankar MaheshRavishankar left a comment


Needs some round trip tests, but apart from that looks ok to me.


ubfx commented Oct 12, 2024

ping

@ubfx ubfx merged commit 02bf3b5 into llvm:main Oct 19, 2024
8 checks passed
@ubfx ubfx deleted the linalg-add-conv2d-nchw-fchw-q branch October 19, 2024 16:25
ubfx added a commit to ubfx/torch-mlir that referenced this pull request Oct 20, 2024
I've upstreamed the necessary quantized linalg Op with the "channel-first"
ordering used by torch (llvm/llvm-project#107740)
for 2d convolution.

This patch changes the lowering for the quantized 2d case of `aten.convolution`
accordingly, which saves three transpositions per convolution (input,
weights, result) and therefore removes the requirement to try to optimize
these away in downstream passes.
ubfx added a commit to ubfx/torch-mlir that referenced this pull request Oct 22, 2024
ubfx added a commit to llvm/torch-mlir that referenced this pull request Oct 22, 2024
…#3807)

@EgorDuplensky

@ubfx Just wondering, why is the memory layout expressed right in the name of the linalg operation?
Shouldn't it at least be an attribute? Or do we need to express it at all? Couldn't we propagate layouts in some extra passes?


ubfx commented Jul 21, 2025

> @ubfx Just wondering, why the memory layout is expressed right in the name of linalg operation? Shouldn't it be at least an attribute? Or do we need to express it at all? Can't we propagate layouts in scope of some extra passes?

Yes, I think the current solution (for all the conv ops in the dialect) isn't ideal, and there have been a couple of suggestions and PRs to improve it, e.g. by expressing the layout in attributes. Unfortunately, linalg seems to be where all of the individual interests collide, which has led to a certain degree of inertia, so I'm not sure whether fixing this is still being actively worked on.
