Skip to content

Commit 068bf9e

Browse files
committed
[mlir][linalg] Define a depthwise 2-D convolution op
This commit defines linalg.depthwise_conv_2d_nhwc for depthwise 2-D convolution with NHWC input/output data format. This op right now only support channel multiplier == 1, which is the most common case. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D94966
1 parent 4c640e4 commit 068bf9e

File tree

3 files changed

+115
-0
lines changed

3 files changed

+115
-0
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,36 @@ def conv_3d_ncdhw(I: f32(N, C, D, H, W), K: f32(F, C, KD, KH, KW)) -> (O: f32(N,
9191
O(n, f, d, h, w) = std_addf<kd, kh, kw>(std_mulf(
9292
I(n, c, d + kd, h + kh, w + kw), K(f, c, kd, kh, kw)));
9393
}
94+
95+
ods_def<DepthwiseConvNHWCOp>:
96+
def depthwise_conv_2d_input_nhwc_filter_hwc
97+
(I: f32(N, IH, IW, C), K: f32(KH, KW, C))
98+
-> (O: f32(N, OH, OW, C))
99+
attr(strides: 2xi64)
100+
"""A depth-wise 2-D convolution operation.
101+
102+
This operation performs depth-wise 2-D convolution over an input `I` and filter
103+
`F` and generates output `O` using the following computation:
104+
105+
```
106+
O(n, oh, ow, c) = std_addf<kh, kw>(std_mulf(
107+
I(n, oh * strides[0] + kh, ow * strides[1] + kw, c), K(kh, kw, c)))
108+
```
109+
110+
where
111+
112+
* `I` is a 4-D tensor with shape `(N, IH, IW, C)`.
113+
* `F` is a 3-D tensor with shape `(KH, KW, C)`.
114+
* `O` is a 4-D tensor with shape `(N, OH, OW, C)`.
115+
* `strides` is a 2-element vector attribute for window strides along the
116+
height/width dimension.
117+
118+
The indexing maps for these three tensors contain 6 dimensions, following the
119+
order of (`N`, `OH`, `OW`, `C`, `KH`, `KW`).
120+
121+
Note: this op only supports channel multiplier == 1.
122+
"""
123+
{
124+
O(n, oh, ow, c) = std_addf<kh, kw>(std_mulf(
125+
I(n, oh * strides[0] + kh, ow * strides[1] + kw, c), K(kh, kw, c)));
126+
}

mlir/test/Dialect/Linalg/generalize-named-ops.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,29 @@ func @generalize_matmul_tensor(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C:
7373
// CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
7474
// CHECK-NEXT: linalg.yield %[[ADD]] : f32
7575
// CHECK-NEXT: -> tensor<16x32xf32>
76+
77+
// -----
78+
79+
func @depthwise_conv_2d_input_nhwc_filter_hwc(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
80+
linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : vector<2xi64>}
81+
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
82+
outs(%output: memref<1x56x56x96xf32>)
83+
return
84+
}
85+
86+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 2 + d5, d3)>
87+
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
88+
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
89+
90+
// CHECK: func @depthwise_conv_2d_input_nhwc_filter_hwc
91+
92+
// CHECK: linalg.generic
93+
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
94+
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
95+
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x113x113x96xf32>, memref<3x3x96xf32>)
96+
// CHECK-SAME: outs(%{{.+}} : memref<1x56x56x96xf32>)
97+
98+
// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
99+
// CHECK-NEXT: %[[MUL:.+]] = mulf %[[BBARG0]], %[[BBARG1]] : f32
100+
// CHECK-NEXT: %[[ADD:.+]] = addf %[[BBARG2]], %[[MUL]] : f32
101+
// CHECK-NEXT: linalg.yield %[[ADD]] : f32
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
2+
3+
// CHECK-LABEL: func @depthwise_conv_2d_input_nhwc_filter_hwc_tensor
4+
func @depthwise_conv_2d_input_nhwc_filter_hwc_tensor(%input: tensor<1x113x113x96xf32>, %filter: tensor<3x3x96xf32>) -> tensor<1x56x56x96xf32> {
5+
%init = linalg.init_tensor [1, 56, 56, 96] : tensor<1x56x56x96xf32>
6+
// CHECK: %{{.+}} = linalg.depthwise_conv_2d_input_nhwc_filter_hwc
7+
// CHECK-SAME: {strides = dense<2> : vector<2xi64>}
8+
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x113x113x96xf32>, tensor<3x3x96xf32>)
9+
// CHECK-SAME: outs(%{{.+}} : tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32>
10+
%0 = linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : vector<2xi64>}
11+
ins(%input, %filter: tensor<1x113x113x96xf32>, tensor<3x3x96xf32>)
12+
outs(%init: tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32>
13+
return %0: tensor<1x56x56x96xf32>
14+
}
15+
16+
// CHECK-LABEL: func @depthwise_conv_2d_input_nhwc_filter_hwc_memref
17+
func @depthwise_conv_2d_input_nhwc_filter_hwc_memref(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
18+
// CHECK: linalg.depthwise_conv_2d_input_nhwc_filter_hwc
19+
// CHECK-SAME: {strides = dense<2> : vector<2xi64>}
20+
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x113x113x96xf32>, memref<3x3x96xf32>)
21+
// CHECK-SAME: outs(%{{.+}} : memref<1x56x56x96xf32>)
22+
linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : vector<2xi64>}
23+
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
24+
outs(%output: memref<1x56x56x96xf32>)
25+
return
26+
}
27+
28+
// -----
29+
30+
func @depthwise_conv_2d_input_nhwc_filter_missing_stride(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
31+
// expected-error @+1 {{missing indexing map required attribute 'strides'}}
32+
linalg.depthwise_conv_2d_input_nhwc_filter_hwc
33+
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
34+
outs(%output: memref<1x56x56x96xf32>)
35+
return
36+
}
37+
38+
// -----
39+
40+
func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_element_type(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
41+
// expected-error @+1 {{incorrect element type for indexing map required attribute 'strides'}}
42+
linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2.0> : vector<2xf32>}
43+
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
44+
outs(%output: memref<1x56x56x96xf32>)
45+
return
46+
}
47+
48+
// -----
49+
50+
func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_size(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
51+
// expected-error @+1 {{incorrect shape for indexing map required attribute 'strides'}}
52+
linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : vector<3xi64> }
53+
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
54+
outs(%output: memref<1x56x56x96xf32>)
55+
return
56+
}

0 commit comments

Comments
 (0)