Skip to content

[mlir][sparse] code cleanup (using inferred type to construct to_[buffer]coordinates/positions/values ops) #83361

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -496,11 +496,11 @@ static Value genFirstPosOrCrds(OpBuilder &builder, Location loc, Value a,
if (format == CuSparseFormat::kCOO) {
// Library uses SoA COO, direct IR uses AoS COO.
if (enableRT)
return genToCoordinates(builder, loc, a, 0);
return genToCoordinatesBuffer(builder, loc, a);
return builder.create<ToCoordinatesOp>(loc, a, 0);
return builder.create<ToCoordinatesBufferOp>(loc, a);
}
// Formats CSR/CSC and BSR use positions at 1.
return genToPositions(builder, loc, a, 1);
return builder.create<ToPositionsOp>(loc, a, 1);
}

/// Generates the second coordinates of a sparse matrix.
Expand All @@ -510,7 +510,7 @@ static Value genSecondCrds(OpBuilder &builder, Location loc, Value a,
if (isCOO && !enableRT)
return Value(); // nothing needed
// Formats CSR/CSC and BSR use coordinates at 1.
return genToCoordinates(builder, loc, a, 1);
return builder.create<ToCoordinatesOp>(loc, a, 1);
}

/// Generates the sparse matrix handle.
Expand Down Expand Up @@ -584,7 +584,7 @@ static LogicalResult rewriteSpMV(PatternRewriter &rewriter,
Value szX = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty
Value memV = genToValues(rewriter, loc, a);
Value memV = rewriter.create<ToValuesOp>(loc, a);
Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
Value valA = genAllocCopy(rewriter, loc, memV, tokens);
Expand Down Expand Up @@ -682,7 +682,7 @@ static LogicalResult rewriteSpMM(PatternRewriter &rewriter,
Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty
Value memV = genToValues(rewriter, loc, a);
Value memV = rewriter.create<ToValuesOp>(loc, a);
Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
Value valA = genAllocCopy(rewriter, loc, memV, tokens);
Expand Down Expand Up @@ -785,10 +785,10 @@ static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter,
Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
Value amemR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
Value amemC = genSecondCrds(rewriter, loc, a, format, enableRT); // not empty
Value amemV = genToValues(rewriter, loc, a);
Value amemV = rewriter.create<ToValuesOp>(loc, a);
Value bmemR = genFirstPosOrCrds(rewriter, loc, b, format, enableRT);
Value bmemC = genSecondCrds(rewriter, loc, b, format, enableRT); // not empty
Value bmemV = genToValues(rewriter, loc, b);
Value bmemV = rewriter.create<ToValuesOp>(loc, b);
Value rowA = genAllocCopy(rewriter, loc, amemR, tokens);
Value colA = genAllocCopy(rewriter, loc, amemC, tokens);
Value valA = genAllocCopy(rewriter, loc, amemV, tokens);
Expand Down Expand Up @@ -1081,7 +1081,7 @@ static LogicalResult rewriteSDDMM(PatternRewriter &rewriter,
Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
Value memR = genFirstPosOrCrds(rewriter, loc, c, format, enableRT);
Value memC = genSecondCrds(rewriter, loc, c, format, enableRT); // or empty
Value memV = genToValues(rewriter, loc, c);
Value memV = rewriter.create<ToValuesOp>(loc, c);
Value rowC = genAllocCopy(rewriter, loc, memR, tokens);
Value colC = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
Value valC = genAllocCopy(rewriter, loc, memV, tokens);
Expand Down
35 changes: 0 additions & 35 deletions mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -554,41 +554,6 @@ sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) {
.getResult();
}

Value sparse_tensor::genToPositions(OpBuilder &builder, Location loc,
                                    Value tensor, Level lvl) {
  // Infer the result type: a 1-D memref of the sparse tensor's position
  // type, with the default unit-stride, zero-offset layout.
  const Type memTp = get1DMemRefType(
      getSparseTensorType(tensor).getPosType(), /*withLayout=*/false);
  return builder.create<ToPositionsOp>(loc, memTp, tensor,
                                       builder.getIndexAttr(lvl));
}

Value sparse_tensor::genToCoordinates(OpBuilder &builder, Location loc,
                                      Value tensor, Level lvl) {
  // Infer the result type: a 1-D memref of the sparse tensor's coordinate
  // type. Levels at or past the AoS COO start get a layout with unknown
  // stride and offset; all other levels use the default layout.
  const auto stt = getSparseTensorType(tensor);
  const bool inCOORegion = lvl >= stt.getAoSCOOStart();
  const Type memTp =
      get1DMemRefType(stt.getCrdType(), /*withLayout=*/inCOORegion);
  return builder.create<ToCoordinatesOp>(loc, memTp, tensor,
                                         builder.getIndexAttr(lvl));
}

Value sparse_tensor::genToCoordinatesBuffer(OpBuilder &builder, Location loc,
                                            Value tensor) {
  // Infer the result type: a 1-D memref of the sparse tensor's coordinate
  // type with no explicit layout (default unit stride, zero offset).
  const Type memTp = get1DMemRefType(
      getSparseTensorType(tensor).getCrdType(), /*withLayout=*/false);
  return builder.create<ToCoordinatesBufferOp>(loc, memTp, tensor);
}

Value sparse_tensor::genToValues(OpBuilder &builder, Location loc,
                                 Value tensor) {
  // Infer the result type: a 1-D memref of the tensor's element type with
  // the default layout.
  const Type elemTp = getRankedTensorType(tensor).getElementType();
  return builder.create<ToValuesOp>(
      loc, get1DMemRefType(elemTp, /*withLayout=*/false), tensor);
}

Value sparse_tensor::genValMemSize(OpBuilder &builder, Location loc,
Value tensor) {
return getDescriptorFromTensorTuple(tensor).getValMemSize(builder, loc);
Expand Down
27 changes: 0 additions & 27 deletions mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -228,17 +228,6 @@ void deallocDenseTensor(OpBuilder &builder, Location loc, Value buffer);
void sizesFromSrc(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
Location loc, Value src);

/// Generates a 1D MemRefType with a dynamic size. When `withLayout` is set,
/// the returned memref has a layout with unknown strides and offset;
/// otherwise, a memref with a standard unit-stride, zero-offset layout is
/// returned.
inline MemRefType get1DMemRefType(Type etp, bool withLayout) {
  // Fix: drop the redundant injected-class-name qualification
  // (`StridedLayoutAttr::StridedLayoutAttr::get` -> `StridedLayoutAttr::get`).
  // A null StridedLayoutAttr selects the default identity layout.
  const auto layout =
      withLayout ? StridedLayoutAttr::get(etp.getContext(),
                                          ShapedType::kDynamic,
                                          {ShapedType::kDynamic})
                 : StridedLayoutAttr();
  return MemRefType::get(ShapedType::kDynamic, etp, layout);
}

/// Scans to top of generated loop.
Operation *getTop(Operation *op);

Expand Down Expand Up @@ -281,22 +270,6 @@ void storeAll(OpBuilder &builder, Location loc, Value mem, ValueRange vs,
TypedValue<BaseMemRefType> genToMemref(OpBuilder &builder, Location loc,
Value tensor);

/// Infers the result type and generates `ToPositionsOp`.
Value genToPositions(OpBuilder &builder, Location loc, Value tensor, Level lvl);

/// Infers the result type and generates `ToCoordinatesOp`. If the
/// level is within a COO region, the result type is a memref with unknown
/// stride and offset. Otherwise, the result type is a memref without
/// any specified layout.
Value genToCoordinates(OpBuilder &builder, Location loc, Value tensor,
Level lvl);

/// Infers the result type and generates `ToCoordinatesBufferOp`.
Value genToCoordinatesBuffer(OpBuilder &builder, Location loc, Value tensor);

/// Infers the result type and generates `ToValuesOp`.
Value genToValues(OpBuilder &builder, Location loc, Value tensor);

/// Generates code to retrieve the values size for the sparse tensor.
Value genValMemSize(OpBuilder &builder, Location loc, Value tensor);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ void LoopEmitter::initializeLoopEmit(
// Annotated sparse tensors.
// We also need the value buffer for all-dense annotated "sparse"
// tensors.
valBuffer[t] = genToValues(builder, loc, tensor);
valBuffer[t] = builder.create<ToValuesOp>(loc, tensor);
}
// NOTE: we can also prepare for 0 lvl here in advance, this will hoist
// some loop preparation from tensor iteration, but will also (undesirably)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1281,21 +1281,21 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
case LevelFormat::Batch:
llvm_unreachable("not implemented");
case LevelFormat::Compressed: {
Value pos = genToPositions(b, l, t, lvl);
Value crd = genToCoordinates(b, l, t, lvl);
Value pos = b.create<ToPositionsOp>(l, t, lvl);
Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
return std::make_unique<CompressedLevel>(tid, lvl, lt, sz, pos, crd);
}
case LevelFormat::LooseCompressed: {
Value pos = genToPositions(b, l, t, lvl);
Value crd = genToCoordinates(b, l, t, lvl);
Value pos = b.create<ToPositionsOp>(l, t, lvl);
Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
return std::make_unique<LooseCompressedLevel>(tid, lvl, lt, sz, pos, crd);
}
case LevelFormat::Singleton: {
Value crd = genToCoordinates(b, l, t, lvl);
Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
return std::make_unique<SingletonLevel>(tid, lvl, lt, sz, crd);
}
case LevelFormat::NOutOfM: {
Value crd = genToCoordinates(b, l, t, lvl);
Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
return std::make_unique<NOutOfMLevel>(tid, lvl, lt, sz, crd);
}
case LevelFormat::Undef:
Expand Down