From 225395d575108e719367759f3767fff68d511a97 Mon Sep 17 00:00:00 2001
From: Matthias Springer
Date: Fri, 23 Feb 2024 14:46:33 +0000
Subject: [PATCH] [mlir][linalg] `LinalgOp`: Disallow mixed tensor/buffer semantics

Related discussion:
https://github.com/llvm/llvm-project/pull/73908/files#r1414913030.

This change fixes #73547.
---
 .../Dialect/Linalg/IR/LinalgInterfaces.cpp |  5 ++
 mlir/test/Dialect/Linalg/canonicalize.mlir | 55 +++++--------------
 .../Linalg/fusion-elementwise-ops.mlir     | 40 --------------
 mlir/test/Dialect/Linalg/invalid.mlir      | 10 ++++
 4 files changed, 29 insertions(+), 81 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 7eed7928456d5..3627ff6617eda 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -1041,6 +1041,11 @@ int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
 LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   LinalgOp linalgOp = cast<LinalgOp>(op);
 
+  // Mixed tensor/buffer operands are not allowed.
+  if (!linalgOp.hasPureTensorSemantics() &&
+      !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
+    return op->emitOpError("expected to have pure tensor or buffer semantics");
+
   // Before checking indexing maps, we need to make sure the attributes
   // referenced by it are valid.
   if (linalgOp.hasDynamicIndexingMaps())
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 7adde3117deea..206d7e9f1ce8d 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -102,17 +102,16 @@ func.func @tensor.cast.unranked(%a : tensor<*xf32>, %b : tensor<*xf32>, %c : ten
 // -----
 
 // CHECK-LABEL: func @linalg_effects(
-//  CHECK-SAME:     %[[A:[a-z0-9]*]]: tensor<?x?xf32>
-//  CHECK-SAME:     %[[B:[a-z0-9]*]]: memref<?x?xf32>
-//  CHECK-SAME:     %[[C:[a-z0-9]*]]: tensor<?x?xf32>
-func.func @linalg_effects(%a : tensor<?x?xf32>, %b : memref<?x?xf32>, %c : tensor<?x?xf32>) {
+func.func @linalg_effects(
+    %a : tensor<?x?xf32>, %b : tensor<?x?xf32>, %c : tensor<?x?xf32>,
+    %d : memref<?x?xf32>, %e : memref<?x?xf32>, %f : memref<?x?xf32>) {
   // CHECK-NOT:   %{{.*}} = linalg.matmul
-  %t = linalg.matmul ins(%a, %b : tensor<?x?xf32>, memref<?x?xf32>)
+  %t = linalg.matmul ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
                     outs(%c : tensor<?x?xf32>) -> tensor<?x?xf32>
 
   // CHECK:   linalg.matmul
-  linalg.matmul ins(%a, %c : tensor<?x?xf32>, tensor<?x?xf32>)
-               outs(%b : memref<?x?xf32>)
+  linalg.matmul ins(%d, %e : memref<?x?xf32>, memref<?x?xf32>)
+               outs(%f : memref<?x?xf32>)
   return
 }
 
@@ -889,11 +888,11 @@ func.func @fold_multi_use_generic_op_with_consumer(%arg0 : tensor) ->
 // -----
 
 #map = affine_map<(d0) -> (d0)>
-func.func @identity_mixed(%arg0 : tensor<?xf32>, %arg1: memref<?xf32>) {
+func.func @identity_buffer(%arg0 : memref<?xf32>, %arg1: memref<?xf32>) {
   linalg.generic {
     indexing_maps = [#map, #map],
     iterator_types = ["parallel"]
-  } ins(%arg0 : tensor<?xf32>)
+  } ins(%arg0 : memref<?xf32>)
     outs(%arg1 : memref<?xf32>) {
   ^bb0(%arg2 : f32, %arg3 : f32):
     linalg.yield %arg2 : f32
@@ -901,14 +900,13 @@ func.func @identity_mixed(%arg0 : tensor<?xf32>, %arg1: memref<?xf32>) {
   return
 }
 
-// There was a crash in EraseIdentityGenericOp for generic with mixed semantics.
-// For now, check generic remained unchanged.
-// CHECK-LABEL: func @identity_mixed
-// CHECK-SAME:  (%[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: memref<?xf32>)
+// Do not erase ops with buffer semantics.
+// CHECK-LABEL: func @identity_buffer
+// CHECK-SAME:  (%[[ARG1:.*]]: memref<?xf32>, %[[ARG2:.*]]: memref<?xf32>)
 //      CHECK:  linalg.generic {
 // CHECK-SAME:    indexing_maps = [#map, #map],
 // CHECK-SAME:    iterator_types = ["parallel"]
-// CHECK-SAME:  } ins(%[[ARG1]] : tensor<?xf32>)
+// CHECK-SAME:  } ins(%[[ARG1]] : memref<?xf32>)
 // CHECK-SAME:    outs(%[[ARG2]] : memref<?xf32>) {
 
 // -----
@@ -916,12 +914,12 @@ func.func @identity_mixed(%arg0 : tensor<?xf32>, %arg1: memref<?xf32>) {
 
 // Just make sure that we don't crash.
 // CHECK-LABEL: func @dedeplicate_regression_test
-func.func @dedeplicate_regression_test(%0: tensor<4xf32>, %1: memref<4xf32>) {
+func.func @dedeplicate_regression_test(%0: tensor<4xf32>, %1: tensor<4xf32>) {
   %36 = linalg.generic
     {indexing_maps = [affine_map<(d0) -> (d0)>,
                       affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
     iterator_types = ["parallel"]}
-    ins(%1, %1 : memref<4xf32>, memref<4xf32>)
+    ins(%1, %1 : tensor<4xf32>, tensor<4xf32>)
     outs(%0 : tensor<4xf32>) {
   ^bb0(%in: f32, %in_24: f32, %out: f32):
     linalg.yield %in : f32
@@ -937,31 +935,6 @@ func.func @dedeplicate_regression_test(%0: tensor<4xf32>, %1: memref<4xf32>) {
 
 // -----
 
-#map = affine_map<(d0) -> (d0)>
-func.func @cast_producer_mixed(%arg0 : tensor<5xf32>, %arg1: memref<?xf32>) {
-  %0 = tensor.cast %arg0 : tensor<5xf32> to tensor<?xf32>
-  linalg.generic {
-    indexing_maps = [#map, #map],
-    iterator_types = ["parallel"]
-  } ins(%0 : tensor<?xf32>)
-    outs(%arg1 : memref<?xf32>) {
-  ^bb0(%arg2 : f32, %arg3 : f32):
-    linalg.yield %arg2 : f32
-  }
-  return
-}
-
-// We need a mixed linalg as a bridge between tensor and memref worlds.
-// CHECK-LABEL: func @cast_producer_mixed
-// CHECK-SAME:  (%[[ARG1:.*]]: tensor<5xf32>, %[[ARG2:.*]]: memref<?xf32>)
-//      CHECK:  linalg.generic {
-// CHECK-SAME:    indexing_maps = [#map, #map],
-// CHECK-SAME:    iterator_types = ["parallel"]
-// CHECK-SAME:  } ins(%[[ARG1]] : tensor<5xf32>)
-// CHECK-SAME:    outs(%[[ARG2]] : memref<?xf32>) {
-
-// -----
-
 // CHECK-LABEL: dead_softmax
 func.func @dead_softmax(%arg0: tensor<16x64x256xf32>) -> tensor<16x64x256xf32> {
   %0 = tensor.empty() : tensor<16x64x256xf32>
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 9d8421cbab49d..15a4f6cdd3bbe 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -1110,43 +1110,3 @@ module {
 //  CHECK-DAG:   %[[T3:.+]] = arith.addf %[[T2]], %[[B1]]
 //      CHECK:   linalg.yield %[[T3]] : f32
 //      CHECK:   return %[[GENERIC]]
-
-// -----
-
-// CHECK-DAG: [[$MAP0:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0, d1)>
-#map0 = affine_map<(d0, d1) -> (d0, d1)>
-
-// CHECK-LABEL: @mixed_fusion
-func.func @mixed_fusion(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>, %arg8 : memref<?x?xf32>)
-{
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
-  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
-  %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
-  %3 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
-      ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
-      outs(%2 : tensor<?x?xf32>) {
-    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
-      %4 = arith.addf %arg3, %arg4 : f32
-      linalg.yield %4 : f32
-  } -> tensor<?x?xf32>
-  // CHECK: linalg.generic {
-  // CHECK-SAME: indexing_maps = {{\[}}[[$MAP0]], [[$MAP0]], [[$MAP0]], [[$MAP0]]{{\]}}
-  linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
-      ins(%3, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>)
-      outs(%arg8 : memref<?x?xf32>) {
-    // CHECK: ^{{[a-zA-Z0-9_]*}}
-    // CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]
-    // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]
-    // CHECK-SAME: [[ARG2:%[a-zA-Z0-9_]*]]
-    ^bb0(%arg5: f32, %arg6: f32, %arg7: f32):
-      // CHECK: [[T1:%[a-zA-Z0-9_]*]] = arith.addf [[ARG0]], [[ARG1]]
-      // CHECK-NOT: linalg.yield
-      // CHECK: arith.mulf [[T1]], [[ARG2]]
-      // CHECK: linalg.yield
-      %5 = arith.mulf %arg5, %arg6 : f32
-      linalg.yield %5 : f32
-  }
-  return
-}
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 916c04f33e9c6..44c81c31ace0f 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -770,3 +770,13 @@ func.func @mmt4d_rank_mismatch(%A: tensor<16x16x8x1xf32>,
                  -> tensor<8x8xf32>
   return %res : tensor<8x8xf32>
 }
+
+// -----
+
+func.func @mixed_semantics(%a: tensor<?x?xf32>, %b: tensor<?x?xf32>, %c: memref<?x?xf32>) {
+  // expected-error @+1 {{expected to have pure tensor or buffer semantics}}
+  linalg.matmul ins(%a, %b: tensor<?x?xf32>, tensor<?x?xf32>)
+                outs(%c: memref<?x?xf32>)
+  return
+}
+
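
For downstream users of linalg, the practical effect is easiest to see on a small matmul sketch
distilled from the tests above (the value names and the %init tensor are illustrative only and not
part of the patch):

  // Now rejected by the verifier: tensor inputs combined with a memref output.
  linalg.matmul ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
                outs(%c : memref<?x?xf32>)

  // Still valid: pure tensor semantics ...
  %r = linalg.matmul ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
                     outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>

  // ... and pure buffer semantics.
  linalg.matmul ins(%d, %e : memref<?x?xf32>, memref<?x?xf32>)
                outs(%f : memref<?x?xf32>)

Ops with no operands are unaffected, since the new check only fires when op->getNumOperands() > 0.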