diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp index 21261478f0648..902adae6feeb1 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp @@ -758,19 +758,28 @@ FailureOr ContractionOpToDotLowering::matchAndRewriteMaskableOp( Value res = rewriter.create(loc, dstType, rewriter.getZeroAttr(dstType)); bool isInt = isa(dstType.getElementType()); + llvm::SmallVector extractedCols; + extractedCols.reserve(dstColumns); for (unsigned r = 0; r < dstRows; ++r) { - Value a = rewriter.create(op.getLoc(), lhs, r); + Value rowLhs = rewriter.create(op.getLoc(), lhs, r); for (unsigned c = 0; c < dstColumns; ++c) { - Value b = rank == 1 - ? rhs - : rewriter.create(op.getLoc(), rhs, c); - Value m = createMul(op.getLoc(), a, b, isInt, rewriter); - Value reduced = rewriter.create( - op.getLoc(), vector::CombiningKind::ADD, m); + // Extract each respective row and column of the LHS and RHS once to + // avoid having duplicate SSA values pointing to the same rows/columns. + if (r == 0) { + Value colRhs = + rank == 1 ? rhs + : rewriter.create(op.getLoc(), rhs, c); + extractedCols.push_back(colRhs); + } + Value extractedColRhs = extractedCols[c]; + Value product = + createMul(op.getLoc(), rowLhs, extractedColRhs, isInt, rewriter); + Value sum = rewriter.create( + op.getLoc(), vector::CombiningKind::ADD, product); SmallVector pos = rank == 1 ? SmallVector{r} : SmallVector{r, c}; - res = rewriter.create(op.getLoc(), reduced, res, pos); + res = rewriter.create(op.getLoc(), sum, res, pos); } } if (auto acc = op.getAcc()) diff --git a/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir index 0ba185bb84760..739796099f795 100644 --- a/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir @@ -151,43 +151,56 @@ func.func @extract_contract3(%arg0: vector<3xf32>, iterator_types = ["parallel", "parallel", "reduction"] } -// CHECK-LABEL: func @extract_contract4 -// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>, -// CHECK-SAME: %[[B:.*1]]: vector<2x2xf32>, -// CHECK-SAME: %[[C:.*2]]: vector<2x2xf32> -// CHECK: %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32> -// CHECK: %[[Bt:.*]] = vector.transpose %arg1, [1, 0] : vector<2x2xf32> to vector<2x2xf32> -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32> -// CHECK: %[[T2:.*]] = vector.extract %[[Bt]][0] : vector<2xf32> from vector<2x2xf32> -// CHECK: %[[T9:.*]] = arith.mulf %[[T0]], %[[T2]] : vector<2xf32> -// CHECK: %[[T10:.*]] = vector.reduction , %[[T9]] : vector<2xf32> into f32 -// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[R]] [0, 0] : f32 into vector<2x2xf32> +// CHECK-LABEL: func @contract_to_dot_matmat +// CHECK-SAME: %[[LHS:.*0]]: vector<2x2xf32>, +// CHECK-SAME: %[[RHS:.*1]]: vector<2x2xf32>, +// CHECK-SAME: %[[OUT:.*2]]: vector<2x2xf32> // -// CHECK: %[[T12:.*]] = vector.extract %[[Bt]][1] : vector<2xf32> from vector<2x2xf32> -// CHECK: %[[T19:.*]] = arith.mulf %[[T0]], %[[T12]] : vector<2xf32> -// CHECK: %[[T20:.*]] = vector.reduction , %[[T19]] : vector<2xf32> into f32 -// CHECK: %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [0, 1] : f32 into vector<2x2xf32> +// The `vector.contract` to dot lowering will 'unroll' a matrix-matrix +// multiplication into individual dot products betweem rows of the LHS with columns +// of the RHS. In the following test we expect 4 extract-dotproduct-insert sequences of +// ops that correspond to the 4 dot products resulting from unrolling a matmul between +// two matrices of size (2, 2). // -// CHECK: %[[T23:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32> -// CHECK: %[[T24:.*]] = vector.extract %[[Bt]][0] : vector<2xf32> from vector<2x2xf32> -// CHECK: %[[T32:.*]] = arith.mulf %[[T23]], %[[T24]] : vector<2xf32> -// CHECK: %[[T33:.*]] = vector.reduction , %[[T32]] : vector<2xf32> into f32 -// CHECK: %[[T34:.*]] = vector.insert %[[T33]], %[[T21]] [1, 0] : f32 into vector<2x2xf32> +// CHECK: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32> // -// CHECK: %[[T40:.*]] = vector.extract %[[Bt]][1] : vector<2xf32> from vector<2x2xf32> -// CHECK: %[[T41:.*]] = arith.mulf %[[T23]], %[[T40]] : vector<2xf32> -// CHECK: %[[T42:.*]] = vector.reduction , %[[T41]] : vector<2xf32> into f32 -// CHECK: %[[T43:.*]] = vector.insert %[[T42]], %[[T34]] [1, 1] : f32 into vector<2x2xf32> +// First, The RHS will be transposed to make it easier to extract individual columns +// using vector.extract. // -// CHECK: %[[T52:.*]] = arith.addf %[[T43]], %[[C]] : vector<2x2xf32> -// CHECK: return %[[T52]] : vector<2x2xf32> +// CHECK: %[[RHS_T:.*]] = vector.transpose %[[RHS]], [1, 0] : vector<2x2xf32> to vector<2x2xf32> +// +// Next, we expect 4 sequences of extracting rows of the RHS, LHS, performing a dot +// product and then inserting it into the result. +// +// CHECK: %[[LHS0:.*]] = vector.extract %[[LHS]][0] : vector<2xf32> from vector<2x2xf32> +// CHECK: %[[RHS_T0:.*]] = vector.extract %[[RHS_T]][0] : vector<2xf32> from vector<2x2xf32> +// CHECK: %[[PROD0:.*]] = arith.mulf %[[LHS0]], %[[RHS_T0]] : vector<2xf32> +// CHECK: %[[SUM0:.*]] = vector.reduction , %[[PROD0]] : vector<2xf32> into f32 +// CHECK: %[[RES0:.*]] = vector.insert %[[SUM0]], %[[INIT]] [0, 0] : f32 into vector<2x2xf32> +// +// CHECK: %[[RHS_T1:.*]] = vector.extract %[[RHS_T]][1] : vector<2xf32> from vector<2x2xf32> +// CHECK: %[[PROD1:.*]] = arith.mulf %[[LHS0]], %[[RHS_T1]] : vector<2xf32> +// CHECK: %[[SUM1:.*]] = vector.reduction , %[[PROD1]] : vector<2xf32> into f32 +// CHECK: %[[RES1:.*]] = vector.insert %[[SUM1]], %[[RES0]] [0, 1] : f32 into vector<2x2xf32> +// +// CHECK: %[[LHS1:.*]] = vector.extract %[[LHS]][1] : vector<2xf32> from vector<2x2xf32> +// CHECK: %[[PROD2:.*]] = arith.mulf %[[LHS1]], %[[RHS_T0]] : vector<2xf32> +// CHECK: %[[SUM2:.*]] = vector.reduction , %[[PROD2]] : vector<2xf32> into f32 +// CHECK: %[[RES2:.*]] = vector.insert %[[SUM2]], %[[RES1]] [1, 0] : f32 into vector<2x2xf32> +// +// CHECK: %[[PROD3:.*]] = arith.mulf %[[LHS1]], %[[RHS_T1]] : vector<2xf32> +// CHECK: %[[SUM3:.*]] = vector.reduction , %[[PROD3]] : vector<2xf32> into f32 +// CHECK: %[[RES3:.*]] = vector.insert %[[SUM3]], %[[RES2]] [1, 1] : f32 into vector<2x2xf32> +// +// CHECK: %[[RES:.*]] = arith.addf %[[RES3]], %[[OUT]] : vector<2x2xf32> +// CHECK: return %[[RES]] : vector<2x2xf32> -func.func @extract_contract4(%arg0: vector<2x2xf32>, - %arg1: vector<2x2xf32>, - %arg2: vector<2x2xf32>) -> vector<2x2xf32> { - %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2 +func.func @contract_to_dot_matmat(%lhs: vector<2x2xf32>, + %rhs: vector<2x2xf32>, + %init: vector<2x2xf32>) -> vector<2x2xf32> { + %res = vector.contract #matmat_trait %lhs, %rhs, %init : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> - return %0 : vector<2x2xf32> + return %res : vector<2x2xf32> }