diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index bf2f26de26e9e..4e3ef937d7d48 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -4024,6 +4024,145 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_arg: K
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: conv_3d_ncdhw_fcdhw_q
+  cpp_class_name: Conv3DNcdhwFcdhwQOp
+  doc: |-
+    Performs 3-D convolution with zero point offsets.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. This includes the zero
+    point offsets common to quantized operations.
+  implements:
+  - LinalgConvolutionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12,
+      s13, s14] -> (s0, s1, s2 * s3 + s4 * s5, s6 * s7 + s8 * s9, s10 * s11 + s12
+      * s13)>
+  - !LinalgOperandDefConfig
+    name: K
+    kind: input_tensor
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12,
+      s13, s14] -> (s14, s1, s4, s8, s12)>
+  - !LinalgOperandDefConfig
+    name: IZp
+    kind: scalar
+    type_var: I32
+  - !LinalgOperandDefConfig
+    name: KZp
+    kind: scalar
+    type_var: I32
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12,
+      s13, s14] -> (s0, s14, s2, s6, s10)>
+  - !LinalgOperandDefConfig
+    name: strides
+    kind: index_attr
+    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
+      s12, s13, s14] -> (s3, s7, s11)>
+    default_indices:
+    - 1
+    - 1
+    - 1
+  - !LinalgOperandDefConfig
+    name: dilations
+    kind: index_attr
+    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
+      s12, s13, s14] -> (s5, s9, s13)>
+    default_indices:
+    - 1
+    - 1
+    - 1
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)[s0, s1, s2, s3, s4, s5, s6,
+      s7, s8, s9, s10, s11, s12, s13, s14] -> (d0, d8, d1 * s3 + d5 * s5, d2 * s7
+      + d6 * s9, d3 * s11 + d7 * s13)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)[s0, s1, s2, s3, s4, s5, s6,
+      s7, s8, s9, s10, s11, s12, s13, s14] -> (d4, d8, d5, d6, d7)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)[s0, s1, s2, s3, s4, s5, s6,
+      s7, s8, s9, s10, s11, s12, s13, s14] -> ()>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)[s0, s1, s2, s3, s4, s5, s6,
+      s7, s8, s9, s10, s11, s12, s13, s14] -> ()>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)[s0, s1, s2, s3, s4, s5, s6,
+      s7, s8, s9, s10, s11, s12, s13, s14] -> (d0, d4, d1, d2, d3)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  - reduction
+  - reduction
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: O
+        - !ScalarExpression
+          scalar_fn:
+            kind: binary
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              scalar_fn:
+                kind: binary
+                fn_name: sub
+                operands:
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: I
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: IZp
+            - !ScalarExpression
+              scalar_fn:
+                kind: binary
+                fn_name: sub
+                operands:
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: K
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: KZp
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: depthwise_conv_1d_nwc_wc
   cpp_class_name: DepthwiseConv1DNwcWcOp
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index b45fecd0ee145..8a733f5ce22eb 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -1127,6 +1127,49 @@ def conv_3d_ncdhw_fcdhw(
     ) * TypeFn.cast_signed(U, K[D.f, D.c, D.kd, D.kh, D.kw])
 
 
+@linalg_structured_op
+def conv_3d_ncdhw_fcdhw_q(
+    I=TensorDef(
+        T1,
+        S.N,
+        S.C,
+        S.OD * S.SD + S.KD * S.DD,
+        S.OH * S.SH + S.KH * S.DH,
+        S.OW * S.SW + S.KW * S.DW,
+    ),
+    K=TensorDef(T2, S.F, S.C, S.KD, S.KH, S.KW),
+    IZp=ScalarDef(I32),
+    KZp=ScalarDef(I32),
+    O=TensorDef(U, S.N, S.F, S.OD, S.OH, S.OW, output=True),
+    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
+):
+    """Performs 3-D convolution with zero point offsets.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. This includes the zero
+    point offsets common to quantized operations.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
+    O[D.n, D.f, D.od, D.oh, D.ow] += (
+        TypeFn.cast_signed(
+            U,
+            I[
+                D.n,
+                D.c,
+                D.od * S.SD + D.kd * S.DD,
+                D.oh * S.SH + D.kh * S.DH,
+                D.ow * S.SW + D.kw * S.DW,
+            ],
+        )
+        - TypeFn.cast_signed(U, IZp)
+    ) * (
+        TypeFn.cast_signed(U, K[D.f, D.c, D.kd, D.kh, D.kw])
+        - TypeFn.cast_signed(U, KZp)
+    )
+
+
 @linalg_structured_op
 def depthwise_conv_1d_nwc_wc(
     I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.IC),
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 1b8969bd11559..6e5adf007f58d 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -694,3 +694,18 @@ func.func @conv2d_channel_first_q_promote(%img: tensor<100x3x224x224xi8>, %filt:
 // CHECK-LABEL: func @conv2d_channel_first_q_promote(
 // CHECK: %[[arg0:[a-zA-z0-9]*]]: tensor<100x3x224x224xi8>, %[[arg1:[a-zA-z0-9]*]]: tensor<64x3x5x5xi8>, %[[arg2:[a-zA-z0-9]*]]: i8, %[[arg3:[a-zA-z0-9]*]]: i8)
 // CHECK: linalg.conv_2d_nchw_fchw_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]] : tensor<100x3x224x224xi8>, tensor<64x3x5x5xi8>, i8, i8) outs(%{{.*}} : tensor<100x64x220x220xi32>) -> tensor<100x64x220x220xi32>
+
+// -----
+
+func.func @conv3d_channel_first_q(%img: tensor<1x27x49x48x47xi8>, %filt: tensor<28x27x3x4x5xi8>, %a: i32, %b: i32) -> tensor<1x28x47x45x43xi32> {
+  %init = arith.constant dense<0> : tensor<1x28x47x45x43xi32>
+  %1 = linalg.conv_3d_ncdhw_fcdhw_q {dilations = dense<1> : tensor<3xi64>,
+                                     strides = dense<1> : tensor<3xi64>}
+    ins(%img, %filt, %a, %b : tensor<1x27x49x48x47xi8>, tensor<28x27x3x4x5xi8>, i32, i32)
+    outs(%init : tensor<1x28x47x45x43xi32>) -> tensor<1x28x47x45x43xi32>
+  return %1 : tensor<1x28x47x45x43xi32>
+}
+
+// CHECK-LABEL: func @conv3d_channel_first_q(
+// CHECK: %[[arg0:[a-zA-z0-9]*]]: tensor<1x27x49x48x47xi8>, %[[arg1:[a-zA-z0-9]*]]: tensor<28x27x3x4x5xi8>, %[[arg2:[a-zA-z0-9]*]]: i32, %[[arg3:[a-zA-z0-9]*]]: i32)
+// CHECK: linalg.conv_3d_ncdhw_fcdhw_q {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} ins(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]] : tensor<1x27x49x48x47xi8>, tensor<28x27x3x4x5xi8>, i32, i32) outs(%{{.*}} : tensor<1x28x47x45x43xi32>) -> tensor<1x28x47x45x43xi32>
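
The roundtrip test above exercises unit strides and dilations only. As a minimal sketch (not part of the patch), the same op with stride 2 would look as follows; the function name and shapes are hypothetical and assume the usual relation OD = (ID - (KD - 1) * DD - 1) / SD + 1, e.g. (49 - 2 - 1) / 2 + 1 = 24 for the depth dimension.

// Hypothetical example, not from the patch: stride-2 quantized 3-D convolution.
func.func @conv3d_channel_first_q_strided(%img: tensor<1x27x49x48x47xi8>, %filt: tensor<28x27x3x4x5xi8>, %a: i32, %b: i32) -> tensor<1x28x24x23x22xi32> {
  %init = arith.constant dense<0> : tensor<1x28x24x23x22xi32>
  // Same quantized accumulation, but the input is sampled every 2 elements
  // along depth, height, and width.
  %1 = linalg.conv_3d_ncdhw_fcdhw_q {dilations = dense<1> : tensor<3xi64>,
                                     strides = dense<2> : tensor<3xi64>}
    ins(%img, %filt, %a, %b : tensor<1x27x49x48x47xi8>, tensor<28x27x3x4x5xi8>, i32, i32)
    outs(%init : tensor<1x28x24x23x22xi32>) -> tensor<1x28x24x23x22xi32>
  return %1 : tensor<1x28x24x23x22xi32>
}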