Skip to content

[mlir][sparse] code cleanup (using inferred type to construct to_[buf… #83361

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 29, 2024

Conversation

PeimingLiu
Copy link
Member

…fer] op).

@llvmbot
Copy link
Member

llvmbot commented Feb 29, 2024

@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir

Author: Peiming Liu (PeimingLiu)

Changes

…fer] op).


Full diff: https://github.com/llvm/llvm-project/pull/83361.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp (+4-19)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h (-11)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
index b888dfadb9c714..370818e0de2578 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
@@ -556,37 +556,22 @@ sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) {
 
 Value sparse_tensor::genToPositions(OpBuilder &builder, Location loc,
                                     Value tensor, Level lvl) {
-  const auto srcTp = getSparseTensorType(tensor);
-  const Type posTp = srcTp.getPosType();
-  const Type memTp = get1DMemRefType(posTp, /*withLayout=*/false);
-  return builder.create<ToPositionsOp>(loc, memTp, tensor,
-                                       builder.getIndexAttr(lvl));
+  return builder.create<ToPositionsOp>(loc, tensor, lvl);
 }
 
 Value sparse_tensor::genToCoordinates(OpBuilder &builder, Location loc,
                                       Value tensor, Level lvl) {
-  const auto srcTp = getSparseTensorType(tensor);
-  const Type crdTp = srcTp.getCrdType();
-  const Type memTp =
-      get1DMemRefType(crdTp, /*withLayout=*/lvl >= srcTp.getAoSCOOStart());
-  return builder.create<ToCoordinatesOp>(loc, memTp, tensor,
-                                         builder.getIndexAttr(lvl));
+  return builder.create<ToCoordinatesOp>(loc, tensor, lvl);
 }
 
 Value sparse_tensor::genToCoordinatesBuffer(OpBuilder &builder, Location loc,
                                             Value tensor) {
-  const auto srcTp = getSparseTensorType(tensor);
-  const Type crdTp = srcTp.getCrdType();
-  const Type memTp = get1DMemRefType(crdTp, /*withLayout=*/false);
-  return builder.create<ToCoordinatesBufferOp>(loc, memTp, tensor);
+  return builder.create<ToCoordinatesBufferOp>(loc, tensor);
 }
 
 Value sparse_tensor::genToValues(OpBuilder &builder, Location loc,
                                  Value tensor) {
-  RankedTensorType srcTp = getRankedTensorType(tensor);
-  Type valTp = get1DMemRefType(srcTp.getElementType(),
-                               /*withLayout=*/false);
-  return builder.create<ToValuesOp>(loc, valTp, tensor);
+  return builder.create<ToValuesOp>(loc, tensor);
 }
 
 Value sparse_tensor::genValMemSize(OpBuilder &builder, Location loc,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
index cc119bc7045595..eb246fc7c45bbb 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
@@ -228,17 +228,6 @@ void deallocDenseTensor(OpBuilder &builder, Location loc, Value buffer);
 void sizesFromSrc(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
                   Location loc, Value src);
 
-/// Generates a 1D MemRefType with a dynamic size. When withLayout is set, the
-/// returned memref has a layout has unknown strides and offsets. Otherwise,
-/// a memref with a standard unit stride zero offset layout is returned.
-inline MemRefType get1DMemRefType(Type etp, bool withLayout) {
-  auto layout = withLayout ? StridedLayoutAttr::StridedLayoutAttr::get(
-                                 etp.getContext(), ShapedType::kDynamic,
-                                 {ShapedType::kDynamic})
-                           : StridedLayoutAttr();
-  return MemRefType::get(ShapedType::kDynamic, etp, layout);
-}
-
 /// Scans to top of generated loop.
 Operation *getTop(Operation *op);
 

Copy link
Contributor

@aartbik aartbik left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like we don't really need the util anymore either, they are longer than an actual build line ;-)

Copy link
Contributor

@aartbik aartbik left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@PeimingLiu PeimingLiu merged commit 1a0986f into llvm:main Feb 29, 2024
@PeimingLiu PeimingLiu deleted the infer-memref branch February 29, 2024 00:55
mylai-mtk pushed a commit to mylai-mtk/llvm-project that referenced this pull request Jul 12, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:sparse Sparse compiler in MLIR mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants