-
Notifications
You must be signed in to change notification settings - Fork 14.6k
[mlir][linalg] Add quantized conv2d operator with FCHW,NCHW order #107740
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This patch adds a quantized version of the `linalg.conv2d_nchw_fchw` Op. This is the "channel-first" ordering typically used by PyTorch and others.
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir Author: Felix Schneider (ubfx) ChangesThis patch adds a quantized version of the Full diff: https://github.com/llvm/llvm-project/pull/107740.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 46b3ec0f60ebfa..4648a9133953af 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -3114,6 +3114,143 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: KZp
--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: conv_2d_nchw_fchw_q
+ cpp_class_name: Conv2DNchwFchwQOp
+ doc: |-
+ Performs 2-D convolution with zero point offsets.
+
+ Layout:
+ * Input: NCHW.
+ * Kernel: FCHW.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output. This includes the zero
+ point offsets common to quantized operations.
+ implements:
+ - LinalgConvolutionOpInterface
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !LinalgOperandDefConfig
+ name: I
+ kind: input_tensor
+ type_var: T1
+ shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0,
+ s1, s2 * s3 + s4 * s5, s6 * s7 + s8 * s9)>
+ - !LinalgOperandDefConfig
+ name: K
+ kind: input_tensor
+ type_var: T2
+ shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s10,
+ s1, s4, s8)>
+ - !LinalgOperandDefConfig
+ name: IZp
+ kind: scalar
+ type_var: I32
+ - !LinalgOperandDefConfig
+ name: KZp
+ kind: scalar
+ type_var: I32
+ - !LinalgOperandDefConfig
+ name: O
+ kind: output_tensor
+ type_var: U
+ shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0,
+ s10, s2, s6)>
+ - !LinalgOperandDefConfig
+ name: strides
+ kind: index_attr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] ->
+ (s3, s7)>
+ default_indices:
+ - 1
+ - 1
+ - !LinalgOperandDefConfig
+ name: dilations
+ kind: index_attr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] ->
+ (s5, s9)>
+ default_indices:
+ - 1
+ - 1
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+ s9, s10] -> (d0, d4, d2 * s3 + d5 * s5, d3 * s7 + d6 * s9)>
+ - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+ s9, s10] -> (d1, d4, d5, d6)>
+ - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+ s9, s10] -> ()>
+ - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+ s9, s10] -> ()>
+ - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+ s9, s10] -> (d0, d1, d2, d3)>
+ iterator_types:
+ - parallel
+ - parallel
+ - parallel
+ - parallel
+ - reduction
+ - reduction
+ - reduction
+ assignments:
+ - !ScalarAssign
+ arg: O
+ value: !ScalarExpression
+ scalar_fn:
+ kind: binary
+ fn_name: add
+ operands:
+ - !ScalarExpression
+ scalar_arg: O
+ - !ScalarExpression
+ scalar_fn:
+ kind: binary
+ fn_name: mul
+ operands:
+ - !ScalarExpression
+ scalar_fn:
+ kind: binary
+ fn_name: sub
+ operands:
+ - !ScalarExpression
+ scalar_fn:
+ kind: type
+ fn_name: cast_signed
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: I
+ - !ScalarExpression
+ scalar_fn:
+ kind: type
+ fn_name: cast_signed
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: IZp
+ - !ScalarExpression
+ scalar_fn:
+ kind: binary
+ fn_name: sub
+ operands:
+ - !ScalarExpression
+ scalar_fn:
+ kind: type
+ fn_name: cast_signed
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: K
+ - !ScalarExpression
+ scalar_fn:
+ kind: type
+ fn_name: cast_signed
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: KZp
+--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: conv_2d_nchw_fchw
cpp_class_name: Conv2DNchwFchwOp
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 67bde8f736ef46..67bae10ad16ca2 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -875,6 +875,34 @@ def conv_2d_nhwc_fhwc_q(
- TypeFn.cast_signed(U, IZp)
) * (TypeFn.cast_signed(U, K[D.f, D.kh, D.kw, D.c]) - TypeFn.cast_signed(U, KZp))
+@linalg_structured_op
+def conv_2d_nchw_fchw_q(
+ I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
+ K=TensorDef(T2, S.F, S.C, S.KH, S.KW),
+ IZp=ScalarDef(I32),
+ KZp=ScalarDef(I32),
+ O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True),
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+ """Performs 2-D convolution with zero point offsets.
+
+ Layout:
+ * Input: NCHW.
+ * Kernel: FCHW.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output. This includes the zero
+ point offsets common to quantized operations.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw)
+ O[D.n, D.f, D.oh, D.ow] += (
+ TypeFn.cast_signed(
+ U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]
+ )
+ - TypeFn.cast_signed(U, IZp)
+ ) * (TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw]) - TypeFn.cast_signed(U, KZp))
@linalg_structured_op
def conv_2d_nchw_fchw(
|
✅ With the latest revision this PR passed the Python code formatter. |
For context: In torch-mlir, we currently have to use an additional transposition on weights and inits for all quantized convolutions. This is because we have no fitting quantized convolution op |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Needs some round trip tests, but apart from that looks ok to me.
ping |
I've upstreamed the necessary quantized linalg Op with the "channel-first" ordering used by torch (llvm/llvm-project#107740) for 2d convolution. This patch changes the lowering for the quantized 2d case of `aten.convolution` accordingly, which saves three transpositions per convolution (input, weights, result) and therefore removes the requirement to try to optimize these away in downstream passes.
I've upstreamed the necessary quantized linalg Op with the "channel-first" ordering used by torch (llvm/llvm-project#107740) for 2d convolution. This patch changes the lowering for the quantized 2d case of `aten.convolution` accordingly, which saves three transpositions per convolution (input, weights, result) and therefore removes the requirement to try to optimize these away in downstream passes.
…#3807) I've upstreamed the necessary quantized linalg Op with the "channel-first" ordering used by torch (llvm/llvm-project#107740) for 2d convolution. This patch changes the lowering for the quantized 2d case of `aten.convolution` accordingly, which saves three transpositions per convolution (input, weights, result) and therefore removes the requirement to try to optimize these away in downstream passes.
@ubfx Just wondering, why the memory layout is expressed right in the name of linalg operation? |
Yes I think the current solution (for all the conv ops in the dialect) isn't ideal and there have been a couple of suggestions and PRs to improve it, e.g. by expressing layout in attributes . Unfortunately, linalg seems to be where all of the individual interests collide which has lead to a certain degree of inertia, so I'm not sure whether fixing this is still actively worked on. |
This patch adds a quantized version of the
linalg.conv2d_nchw_fchw
Op. This is the "channel-first" ordering typically used by PyTorch and others.