diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 6473c92a91aa6..1222542ee39fd 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -912,22 +912,27 @@ Type ContractionOp::getExpectedMaskType() { unsigned numVecDims = lhsIdxMap.getNumDims(); SmallVector maskShape(numVecDims, ShapedType::kDynamic); + SmallVector maskShapeScalableDims(numVecDims, false); // Using the information in the indexing maps, extract the size of each // dimension in the vector.contract operation from the two input operands. - for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) + for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) { maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize; - for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) + maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] = + lhsType.getScalableDims()[dimIdx]; + } + for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) { maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize; + maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] = + rhsType.getScalableDims()[dimIdx]; + } assert(!ShapedType::isDynamicShape(maskShape) && "Mask shape couldn't be computed"); - // TODO: Extend the scalable vector type representation with a bit map. - assert(!lhsType.isScalable() && !rhsType.isScalable() && - "Scalable vectors are not supported yet"); return VectorType::get(maskShape, - IntegerType::get(lhsType.getContext(), /*width=*/1)); + IntegerType::get(lhsType.getContext(), /*width=*/1), + maskShapeScalableDims); } SmallVector ContractionOp::getTraitAttrNames() { diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index f00bc6e97b350..d41cee5ea67b0 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -979,3 +979,27 @@ func.func @vector_scalable_extract(%sv: vector<[8]xi32>) { %2 = vector.scalable.extract %sv[4] : vector<4xi32> from vector<[8]xi32> return } + +#matmat_accesses = [ + affine_map<(i, j, k) -> (i, k)>, + affine_map<(i, j, k) -> (k, j)>, + affine_map<(i, j, k) -> (i, j)> +] +#matmat_trait = { + indexing_maps = #matmat_accesses, + iterator_types = ["parallel", "parallel", "reduction"] +} +// CHECK-LABEL: func.func @contraction_masked_scalable( +// CHECK-SAME: %[[A:.*]]: vector<3x4xf32>, +// CHECK-SAME: %[[B:.*]]: vector<4x[8]xf32>, +// CHECK-SAME: %[[C:.*]]: vector<3x[8]xf32>, +// CHECK-SAME: %[[M:.*]]: vector<3x[8]x4xi1>) -> vector<3x[8]xf32> { +func.func @contraction_masked_scalable(%A: vector<3x4xf32>, + %B: vector<4x[8]xf32>, + %C: vector<3x[8]xf32>, + %M : vector<3x[8]x4xi1>) -> vector<3x[8]xf32> { + // CHECK: vector.mask %[[M]] { vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[A]], %[[B]], %[[C]] : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } : vector<3x[8]x4xi1> -> vector<3x[8]xf32> + %0 = vector.mask %M { vector.contract #matmat_trait %A, %B, %C : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } + : vector<3x[8]x4xi1> -> vector<3x[8]xf32> + return %0 : vector<3x[8]xf32> +}