Skip to content

Commit 1a0986f

Browse files
authored
[mlir][sparse] code cleanup (using inferred type to construct to_[buf… (#83361)
…fer] op).
1 parent f7a544d commit 1a0986f

File tree

5 files changed

+16
-78
lines changed

5 files changed

+16
-78
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp

+9-9
Original file line numberDiff line numberDiff line change
@@ -496,11 +496,11 @@ static Value genFirstPosOrCrds(OpBuilder &builder, Location loc, Value a,
496496
if (format == CuSparseFormat::kCOO) {
497497
// Library uses SoA COO, direct IR uses AoS COO.
498498
if (enableRT)
499-
return genToCoordinates(builder, loc, a, 0);
500-
return genToCoordinatesBuffer(builder, loc, a);
499+
return builder.create<ToCoordinatesOp>(loc, a, 0);
500+
return builder.create<ToCoordinatesBufferOp>(loc, a);
501501
}
502502
// Formats CSR/CSC and BSR use positions at 1.
503-
return genToPositions(builder, loc, a, 1);
503+
return builder.create<ToPositionsOp>(loc, a, 1);
504504
}
505505

506506
/// Generates the second coordinates of a sparse matrix.
@@ -510,7 +510,7 @@ static Value genSecondCrds(OpBuilder &builder, Location loc, Value a,
510510
if (isCOO && !enableRT)
511511
return Value(); // nothing needed
512512
// Formats CSR/CSC and BSR use coordinates at 1.
513-
return genToCoordinates(builder, loc, a, 1);
513+
return builder.create<ToCoordinatesOp>(loc, a, 1);
514514
}
515515

516516
/// Generates the sparse matrix handle.
@@ -584,7 +584,7 @@ static LogicalResult rewriteSpMV(PatternRewriter &rewriter,
584584
Value szX = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
585585
Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
586586
Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty
587-
Value memV = genToValues(rewriter, loc, a);
587+
Value memV = rewriter.create<ToValuesOp>(loc, a);
588588
Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
589589
Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
590590
Value valA = genAllocCopy(rewriter, loc, memV, tokens);
@@ -682,7 +682,7 @@ static LogicalResult rewriteSpMM(PatternRewriter &rewriter,
682682
Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
683683
Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
684684
Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty
685-
Value memV = genToValues(rewriter, loc, a);
685+
Value memV = rewriter.create<ToValuesOp>(loc, a);
686686
Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
687687
Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
688688
Value valA = genAllocCopy(rewriter, loc, memV, tokens);
@@ -785,10 +785,10 @@ static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter,
785785
Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
786786
Value amemR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
787787
Value amemC = genSecondCrds(rewriter, loc, a, format, enableRT); // not empty
788-
Value amemV = genToValues(rewriter, loc, a);
788+
Value amemV = rewriter.create<ToValuesOp>(loc, a);
789789
Value bmemR = genFirstPosOrCrds(rewriter, loc, b, format, enableRT);
790790
Value bmemC = genSecondCrds(rewriter, loc, b, format, enableRT); // not empty
791-
Value bmemV = genToValues(rewriter, loc, b);
791+
Value bmemV = rewriter.create<ToValuesOp>(loc, b);
792792
Value rowA = genAllocCopy(rewriter, loc, amemR, tokens);
793793
Value colA = genAllocCopy(rewriter, loc, amemC, tokens);
794794
Value valA = genAllocCopy(rewriter, loc, amemV, tokens);
@@ -1081,7 +1081,7 @@ static LogicalResult rewriteSDDMM(PatternRewriter &rewriter,
10811081
Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
10821082
Value memR = genFirstPosOrCrds(rewriter, loc, c, format, enableRT);
10831083
Value memC = genSecondCrds(rewriter, loc, c, format, enableRT); // or empty
1084-
Value memV = genToValues(rewriter, loc, c);
1084+
Value memV = rewriter.create<ToValuesOp>(loc, c);
10851085
Value rowC = genAllocCopy(rewriter, loc, memR, tokens);
10861086
Value colC = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
10871087
Value valC = genAllocCopy(rewriter, loc, memV, tokens);

mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp

-35
Original file line numberDiff line numberDiff line change
@@ -554,41 +554,6 @@ sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) {
554554
.getResult();
555555
}
556556

557-
Value sparse_tensor::genToPositions(OpBuilder &builder, Location loc,
558-
Value tensor, Level lvl) {
559-
const auto srcTp = getSparseTensorType(tensor);
560-
const Type posTp = srcTp.getPosType();
561-
const Type memTp = get1DMemRefType(posTp, /*withLayout=*/false);
562-
return builder.create<ToPositionsOp>(loc, memTp, tensor,
563-
builder.getIndexAttr(lvl));
564-
}
565-
566-
Value sparse_tensor::genToCoordinates(OpBuilder &builder, Location loc,
567-
Value tensor, Level lvl) {
568-
const auto srcTp = getSparseTensorType(tensor);
569-
const Type crdTp = srcTp.getCrdType();
570-
const Type memTp =
571-
get1DMemRefType(crdTp, /*withLayout=*/lvl >= srcTp.getAoSCOOStart());
572-
return builder.create<ToCoordinatesOp>(loc, memTp, tensor,
573-
builder.getIndexAttr(lvl));
574-
}
575-
576-
Value sparse_tensor::genToCoordinatesBuffer(OpBuilder &builder, Location loc,
577-
Value tensor) {
578-
const auto srcTp = getSparseTensorType(tensor);
579-
const Type crdTp = srcTp.getCrdType();
580-
const Type memTp = get1DMemRefType(crdTp, /*withLayout=*/false);
581-
return builder.create<ToCoordinatesBufferOp>(loc, memTp, tensor);
582-
}
583-
584-
Value sparse_tensor::genToValues(OpBuilder &builder, Location loc,
585-
Value tensor) {
586-
RankedTensorType srcTp = getRankedTensorType(tensor);
587-
Type valTp = get1DMemRefType(srcTp.getElementType(),
588-
/*withLayout=*/false);
589-
return builder.create<ToValuesOp>(loc, valTp, tensor);
590-
}
591-
592557
Value sparse_tensor::genValMemSize(OpBuilder &builder, Location loc,
593558
Value tensor) {
594559
return getDescriptorFromTensorTuple(tensor).getValMemSize(builder, loc);

mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h

-27
Original file line numberDiff line numberDiff line change
@@ -228,17 +228,6 @@ void deallocDenseTensor(OpBuilder &builder, Location loc, Value buffer);
228228
void sizesFromSrc(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
229229
Location loc, Value src);
230230

231-
/// Generates a 1D MemRefType with a dynamic size. When withLayout is set, the
232-
/// returned memref has a layout has unknown strides and offsets. Otherwise,
233-
/// a memref with a standard unit stride zero offset layout is returned.
234-
inline MemRefType get1DMemRefType(Type etp, bool withLayout) {
235-
auto layout = withLayout ? StridedLayoutAttr::StridedLayoutAttr::get(
236-
etp.getContext(), ShapedType::kDynamic,
237-
{ShapedType::kDynamic})
238-
: StridedLayoutAttr();
239-
return MemRefType::get(ShapedType::kDynamic, etp, layout);
240-
}
241-
242231
/// Scans to top of generated loop.
243232
Operation *getTop(Operation *op);
244233

@@ -281,22 +270,6 @@ void storeAll(OpBuilder &builder, Location loc, Value mem, ValueRange vs,
281270
TypedValue<BaseMemRefType> genToMemref(OpBuilder &builder, Location loc,
282271
Value tensor);
283272

284-
/// Infers the result type and generates `ToPositionsOp`.
285-
Value genToPositions(OpBuilder &builder, Location loc, Value tensor, Level lvl);
286-
287-
/// Infers the result type and generates `ToCoordinatesOp`. If the
288-
/// level is within a COO region, the result type is a memref with unknown
289-
/// stride and offset. Otherwise, the result type is a memref without
290-
/// any specified layout.
291-
Value genToCoordinates(OpBuilder &builder, Location loc, Value tensor,
292-
Level lvl);
293-
294-
/// Infers the result type and generates `ToCoordinatesBufferOp`.
295-
Value genToCoordinatesBuffer(OpBuilder &builder, Location loc, Value tensor);
296-
297-
/// Infers the result type and generates `ToValuesOp`.
298-
Value genToValues(OpBuilder &builder, Location loc, Value tensor);
299-
300273
/// Generates code to retrieve the values size for the sparse tensor.
301274
Value genValMemSize(OpBuilder &builder, Location loc, Value tensor);
302275

mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ void LoopEmitter::initializeLoopEmit(
259259
// Annotated sparse tensors.
260260
// We also need the value buffer for all-dense annotated "sparse"
261261
// tensors.
262-
valBuffer[t] = genToValues(builder, loc, tensor);
262+
valBuffer[t] = builder.create<ToValuesOp>(loc, tensor);
263263
}
264264
// NOTE: we can also prepare for 0 lvl here in advance, this will hoist
265265
// some loop preparation from tensor iteration, but will also (undesirably)

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -1281,21 +1281,21 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
12811281
case LevelFormat::Batch:
12821282
llvm_unreachable("not implemented");
12831283
case LevelFormat::Compressed: {
1284-
Value pos = genToPositions(b, l, t, lvl);
1285-
Value crd = genToCoordinates(b, l, t, lvl);
1284+
Value pos = b.create<ToPositionsOp>(l, t, lvl);
1285+
Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
12861286
return std::make_unique<CompressedLevel>(tid, lvl, lt, sz, pos, crd);
12871287
}
12881288
case LevelFormat::LooseCompressed: {
1289-
Value pos = genToPositions(b, l, t, lvl);
1290-
Value crd = genToCoordinates(b, l, t, lvl);
1289+
Value pos = b.create<ToPositionsOp>(l, t, lvl);
1290+
Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
12911291
return std::make_unique<LooseCompressedLevel>(tid, lvl, lt, sz, pos, crd);
12921292
}
12931293
case LevelFormat::Singleton: {
1294-
Value crd = genToCoordinates(b, l, t, lvl);
1294+
Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
12951295
return std::make_unique<SingletonLevel>(tid, lvl, lt, sz, crd);
12961296
}
12971297
case LevelFormat::NOutOfM: {
1298-
Value crd = genToCoordinates(b, l, t, lvl);
1298+
Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
12991299
return std::make_unique<NOutOfMLevel>(tid, lvl, lt, sz, crd);
13001300
}
13011301
case LevelFormat::Undef:

0 commit comments

Comments
 (0)