Skip to content

Conversation

yinying-lisa-li
Copy link
Contributor

@yinying-lisa-li yinying-lisa-li commented Apr 25, 2024

  1. Verify that the type of explicit/implicit values should be the same as the tensor element type.
  2. Verify that implicit value could only be zero.
  3. Verify that explicit/implicit values should be numeric.
  4. Fix the type change issue caused by SparseTensorType(enc).

@yinying-lisa-li yinying-lisa-li marked this pull request as ready for review April 25, 2024 21:26
@llvmbot llvmbot added mlir:sparse Sparse compiler in MLIR mlir labels Apr 25, 2024
@llvmbot
Copy link
Member

llvmbot commented Apr 25, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sparse

Author: Yinying Li (yinying-lisa-li)

Changes
  1. Verify that the type of explicit/implicit values should be the same as the tensor element type.
  2. Verify that implicit value could only be zero.
  3. Verify that explicit/implicit values should be numeric.

Full diff: https://github.com/llvm/llvm-project/pull/90111.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td (+5)
  • (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+36)
  • (modified) mlir/test/Dialect/SparseTensor/invalid_encoding.mlir (+72)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index eefa4c71bbd2ca..37fa4913aa6a60 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -512,6 +512,11 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     void printSymbols(AffineMap &map, AsmPrinter &printer) const;
     void printDimensions(AffineMap &map, AsmPrinter &printer, ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr> dimSlices) const;
     void printLevels(AffineMap &map, AsmPrinter &printer, ArrayRef<::mlir::sparse_tensor::LevelType> lvlTypes) const;
+
+    //
+    // Explicit/implicit value methods.
+    //
+    Type getMismatchedValueType(Type elementType, Attribute val) const;
   }];
 
   let genVerifyDecl = 1;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 028a69da10c1e1..7c938ecaed5abe 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -888,6 +888,19 @@ LogicalResult SparseTensorEncodingAttr::verify(
   return success();
 }
 
+Type SparseTensorEncodingAttr::getMismatchedValueType(Type elementType,
+                                                      Attribute val) const {
+  Type type;
+  auto fVal = llvm::dyn_cast<FloatAttr>(val);
+  auto intVal = llvm::dyn_cast<IntegerAttr>(val);
+  if (fVal && fVal.getType() != elementType) {
+    type = fVal.getType();
+  } else if (intVal && intVal.getType() != elementType) {
+    type = intVal.getType();
+  }
+  return type;
+}
+
 LogicalResult SparseTensorEncodingAttr::verifyEncoding(
     ArrayRef<Size> dimShape, Type elementType,
     function_ref<InFlightDiagnostic()> emitError) const {
@@ -907,6 +920,29 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
     return emitError()
            << "dimension-rank mismatch between encoding and tensor shape: "
            << getDimRank() << " != " << dimRank;
+  Type type;
+  if (getExplicitVal()) {
+    if ((type = getMismatchedValueType(elementType, getExplicitVal()))) {
+      return emitError() << "explicit value type mismatch between encoding and "
+                         << "tensor element type: " << type
+                         << " != " << elementType;
+    }
+  }
+  if (getImplicitVal()) {
+    auto impVal = getImplicitVal();
+    if ((type = getMismatchedValueType(elementType, impVal))) {
+      return emitError() << "implicit value type mismatch between encoding and "
+                         << "tensor element type: " << type
+                         << " != " << elementType;
+    }
+    // Currently, we only support zero as the implicit value.
+    auto impFVal = llvm::dyn_cast<FloatAttr>(impVal);
+    auto impIntVal = llvm::dyn_cast<IntegerAttr>(impVal);
+    if ((impFVal && impFVal.getValueAsDouble() != 0.0) ||
+        (impIntVal && impIntVal.getInt() != 0)) {
+      return emitError() << "implicit value must be zero";
+    }
+  }
   return success();
 }
 
diff --git a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
index 8096c010ac935a..19e8fc95e22813 100644
--- a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
@@ -443,3 +443,75 @@ func.func private @NOutOfM(%arg0: tensor<?x?x?xf64, #NOutOfM>) {
 func.func private @NOutOfM(%arg0: tensor<?x?x?xf64, #NOutOfM>) {
   return
 }
+
+// -----
+
+#CSR_ExpType = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : dense, d1 : compressed),
+  posWidth = 32,
+  crdWidth = 32,
+  explicitVal = 1 : i32,
+  implicitVal = 0.0 : f32
+}>
+
+// expected-error@+1 {{explicit value type mismatch between encoding and tensor element type: 'i32' != 'f32'}}
+func.func private @sparse_csr(tensor<?x?xf32, #CSR_ExpType>)
+
+// -----
+
+#CSR_ImpType = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : dense, d1 : compressed),
+  posWidth = 32,
+  crdWidth = 32,
+  explicitVal = 1 : i32,
+  implicitVal = 0.0 : f32
+}>
+
+// expected-error@+1 {{implicit value type mismatch between encoding and tensor element type: 'f32' != 'i32'}}
+func.func private @sparse_csr(tensor<?x?xi32, #CSR_ImpType>)
+
+// -----
+
+// expected-error@+1 {{expected a numeric value for explicitVal}}
+#CSR_ExpType = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : dense, d1 : compressed),
+  posWidth = 32,
+  crdWidth = 32,
+  explicitVal = "str"
+}>
+func.func private @sparse_csr(tensor<?x?xi32, #CSR_ExpType>)
+
+// -----
+
+// expected-error@+1 {{expected a numeric value for implicitVal}}
+#CSR_ImpType = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : dense, d1 : compressed),
+  posWidth = 32,
+  crdWidth = 32,
+  implicitVal = "str"
+}>
+func.func private @sparse_csr(tensor<?x?xi32, #CSR_ImpType>)
+
+// -----
+
+#CSR_ImpVal = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : dense, d1 : compressed),
+  posWidth = 32,
+  crdWidth = 32,
+  implicitVal = 1 : i32
+}>
+
+// expected-error@+1 {{implicit value must be zero}}
+func.func private @sparse_csr(tensor<?x?xi32, #CSR_ImpVal>)
+
+// -----
+
+#CSR_ImpVal = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : dense, d1 : compressed),
+  posWidth = 32,
+  crdWidth = 32,
+  implicitVal = 1.0 : f32
+}>
+
+// expected-error@+1 {{implicit value must be zero}}
+func.func private @sparse_csr(tensor<?x?xf32, #CSR_ImpVal>)

Copy link

github-actions bot commented May 2, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@yinying-lisa-li yinying-lisa-li merged commit 83f3b1c into llvm:main May 8, 2024
@yinying-lisa-li yinying-lisa-li deleted the verify branch May 8, 2024 00:28
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:sparse Sparse compiler in MLIR mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants