|
| 1 | +// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s |
| 2 | + |
| 3 | +// CHECK-LABEL: func @depthwise_conv_2d_input_nhwc_filter_hwc_tensor |
| 4 | +func @depthwise_conv_2d_input_nhwc_filter_hwc_tensor(%input: tensor<1x113x113x96xf32>, %filter: tensor<3x3x96xf32>) -> tensor<1x56x56x96xf32> { |
| 5 | + %init = linalg.init_tensor [1, 56, 56, 96] : tensor<1x56x56x96xf32> |
| 6 | + // CHECK: %{{.+}} = linalg.depthwise_conv_2d_input_nhwc_filter_hwc |
| 7 | + // CHECK-SAME: {strides = dense<2> : vector<2xi64>} |
| 8 | + // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x113x113x96xf32>, tensor<3x3x96xf32>) |
| 9 | + // CHECK-SAME: outs(%{{.+}} : tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32> |
| 10 | + %0 = linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : vector<2xi64>} |
| 11 | + ins(%input, %filter: tensor<1x113x113x96xf32>, tensor<3x3x96xf32>) |
| 12 | + outs(%init: tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32> |
| 13 | + return %0: tensor<1x56x56x96xf32> |
| 14 | +} |
| 15 | + |
| 16 | +// CHECK-LABEL: func @depthwise_conv_2d_input_nhwc_filter_hwc_memref |
| 17 | +func @depthwise_conv_2d_input_nhwc_filter_hwc_memref(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) { |
| 18 | + // CHECK: linalg.depthwise_conv_2d_input_nhwc_filter_hwc |
| 19 | + // CHECK-SAME: {strides = dense<2> : vector<2xi64>} |
| 20 | + // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x113x113x96xf32>, memref<3x3x96xf32>) |
| 21 | + // CHECK-SAME: outs(%{{.+}} : memref<1x56x56x96xf32>) |
| 22 | + linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : vector<2xi64>} |
| 23 | + ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>) |
| 24 | + outs(%output: memref<1x56x56x96xf32>) |
| 25 | + return |
| 26 | +} |
| 27 | + |
| 28 | +// ----- |
| 29 | + |
| 30 | +func @depthwise_conv_2d_input_nhwc_filter_missing_stride(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) { |
| 31 | + // expected-error @+1 {{missing indexing map required attribute 'strides'}} |
| 32 | + linalg.depthwise_conv_2d_input_nhwc_filter_hwc |
| 33 | + ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>) |
| 34 | + outs(%output: memref<1x56x56x96xf32>) |
| 35 | + return |
| 36 | +} |
| 37 | + |
| 38 | +// ----- |
| 39 | + |
| 40 | +func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_element_type(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) { |
| 41 | + // expected-error @+1 {{incorrect element type for indexing map required attribute 'strides'}} |
| 42 | + linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2.0> : vector<2xf32>} |
| 43 | + ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>) |
| 44 | + outs(%output: memref<1x56x56x96xf32>) |
| 45 | + return |
| 46 | +} |
| 47 | + |
| 48 | +// ----- |
| 49 | + |
| 50 | +func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_size(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) { |
| 51 | + // expected-error @+1 {{incorrect shape for indexing map required attribute 'strides'}} |
| 52 | + linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : vector<3xi64> } |
| 53 | + ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>) |
| 54 | + outs(%output: memref<1x56x56x96xf32>) |
| 55 | + return |
| 56 | +} |
0 commit comments