diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp index 3a487a3bd6a06..2b81d6cdc1eab 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp @@ -316,6 +316,12 @@ static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp, if (auto load = cast.getDefiningOp()) { Value inv = load.getOperand(0); Value idx = load.getOperand(1); + // Swap non-invariant. + if (!isInvariantValue(inv, block)) { + inv = idx; + idx = load.getOperand(0); + } + // Inspect. if (isInvariantValue(inv, block)) { if (auto arg = llvm::dyn_cast(idx)) { if (isInvariantArg(arg, block) || !innermost) diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_mv.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_mv.mlir index dfee2b1261b6c..e25c3a02f9127 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_vector_mv.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_vector_mv.mlir @@ -1,4 +1,3 @@ -// FIXME: re-enable. // RUN: mlir-opt %s -sparsifier="vl=8" | FileCheck %s #Dense = #sparse_tensor.encoding<{ @@ -16,7 +15,7 @@ } // CHECK-LABEL: llvm.func @kernel_matvec -// C_HECK: llvm.intr.vector.reduce.fadd +// CHECK: llvm.intr.vector.reduce.fadd func.func @kernel_matvec(%arga: tensor, %argb: tensor, %argx: tensor) -> tensor {