Skip to content

[mlir][SparseTensor][NFC] Pass tensor type to descriptor helper #116468

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 38 additions & 20 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -646,10 +646,11 @@ class SparseLvlOpConverter : public OpConversionPattern<LvlOp> {
matchAndRewrite(LvlOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
std::optional<int64_t> lvl = op.getConstantLvlIndex();
if (!lvl || !getSparseTensorEncoding(adaptor.getSource().getType()))
RankedTensorType srcType = op.getSource().getType();
if (!lvl || !getSparseTensorEncoding(srcType))
return failure();

auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
auto desc = getDescriptorFromTensorTuple(adaptor.getSource(), srcType);
auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl);

rewriter.replaceOp(op, sz);
Expand All @@ -675,8 +676,9 @@ struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
assert(dstStt.hasSameDimToLvl(srcStt));

// We don't need a mutable descriptor here as we perform sorting in-place.
auto nnz = genValMemSize(rewriter, op.getLoc(), adaptor.getInputCoo());
auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo());
auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo(),
op.getInputCoo().getType());
auto nnz = desc.getValMemSize(rewriter, op.getLoc());
auto crd = desc.getAOSMemRef();
auto val = desc.getValMemRef();

Expand Down Expand Up @@ -704,7 +706,8 @@ class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Simply lowers to specifer.get <field> operation.
auto desc = getDescriptorFromTensorTuple(adaptor.getSlice());
auto desc = getDescriptorFromTensorTuple(adaptor.getSlice(),
op.getSlice().getType());
auto v = desc.getSpecifierField(rewriter, op.getLoc(), kind,
op.getDim().getZExtValue());

Expand Down Expand Up @@ -762,7 +765,8 @@ class SparseTensorAllocConverter
Location loc = op.getLoc();
// Deal with copy.
if (op.getCopy()) {
auto desc = getDescriptorFromTensorTuple(adaptor.getCopy());
auto desc = getDescriptorFromTensorTuple(
adaptor.getCopy(), cast<RankedTensorType>(op.getCopy().getType()));
SmallVector<Value> fields;
fields.reserve(desc.getNumFields());
// Memcpy on memref fields.
Expand Down Expand Up @@ -868,7 +872,9 @@ class SparseTensorDeallocConverter
if (createDeallocs) {
// Replace the sparse tensor deallocation with field deallocations.
Location loc = op.getLoc();
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
auto desc = getDescriptorFromTensorTuple(
adaptor.getTensor(),
cast<RankedTensorType>(op.getTensor().getType()));
for (auto input : desc.getMemRefFields())
// Deallocate every buffer used to store the sparse tensor handler.
rewriter.create<memref::DeallocOp>(loc, input);
Expand All @@ -889,7 +895,8 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
matchAndRewrite(LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Prepare descriptor.
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
op.getTensor().getType());
// Generate optional insertion finalization code.
if (op.getHasInserts())
genEndInsert(rewriter, op.getLoc(), desc);
Expand All @@ -909,7 +916,8 @@ class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
if (!getSparseTensorEncoding(op.getTensor().getType()))
return failure();
Location loc = op->getLoc();
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
op.getTensor().getType());
const auto srcType = getSparseTensorType(op.getTensor());
Type eltType = srcType.getElementType();
Type boolType = rewriter.getIntegerType(1);
Expand Down Expand Up @@ -959,7 +967,8 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
SmallVector<Value> fields;
auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields,
op.getTensor().getType());
Value values = adaptor.getValues();
Value filled = adaptor.getFilled();
Value added = adaptor.getAdded();
Expand Down Expand Up @@ -1032,7 +1041,8 @@ class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
assert(stt.isIdentity() && "Run reinterpret-map before conversion.");

Location loc = op.getLoc();
auto desc = getDescriptorFromTensorTuple(adaptor.getDest());
auto desc =
getDescriptorFromTensorTuple(adaptor.getDest(), op.getDest().getType());
TypeRange flatSpTensorTps = desc.getFields().getTypes();
SmallVector<Value> params = llvm::to_vector(desc.getFields());
params.append(adaptor.getIndices().begin(), adaptor.getIndices().end());
Expand All @@ -1059,7 +1069,8 @@ class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> {
// of this operation truly observe size, not capacity!
Location loc = op.getLoc();
Level lvl = op.getLevel();
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
op.getTensor().getType());
auto mem = desc.getPosMemRef(lvl);
auto size = desc.getPosMemSize(rewriter, loc, lvl);
rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
Expand All @@ -1081,7 +1092,8 @@ class SparseToCoordinatesConverter
// of this operation truly observe size, not capacity!
Location loc = op.getLoc();
Level lvl = op.getLevel();
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
op.getTensor().getType());
auto mem = desc.getCrdMemRefOrView(rewriter, loc, lvl);
if (lvl < getSparseTensorType(op.getTensor()).getAoSCOOStart()) {
auto size = desc.getCrdMemSize(rewriter, loc, lvl);
Expand All @@ -1106,7 +1118,8 @@ class SparseToCoordinatesBufferConverter
// of this operation truly observe size, not capacity!
Location loc = op.getLoc();
Level lvl = getSparseTensorType(op.getTensor()).getAoSCOOStart();
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
op.getTensor().getType());
auto mem = desc.getAOSMemRef();
auto size = desc.getCrdMemSize(rewriter, loc, lvl);
rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
Expand All @@ -1126,7 +1139,8 @@ class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
// The view is restricted to the actual size to ensure clients
// of this operation truly observe size, not capacity!
Location loc = op.getLoc();
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
op.getTensor().getType());
auto mem = desc.getValMemRef();
auto size = desc.getValMemSize(rewriter, loc);
rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
Expand Down Expand Up @@ -1172,7 +1186,8 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
// else:
// dst = memref.copy(src)
Location loc = op.getLoc();
auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource());
auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource(),
op.getSource().getType());
SmallVector<Value> fields;
foreachFieldAndTypeInSparseTensor(
SparseTensorType(cast<RankedTensorType>(op.getResult().getType())),
Expand Down Expand Up @@ -1236,7 +1251,8 @@ class SparseExtractSliceConverter
assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());

SmallVector<Value> fields;
auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields);
auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields,
op.getSource().getType());

auto newSpec = rewriter.create<StorageSpecifierInitOp>(
loc, StorageSpecifierType::get(ctx, dstEnc), desc.getSpecifier());
Expand Down Expand Up @@ -1285,8 +1301,9 @@ class SparseNumberOfEntriesConverter
// Query memSizes for the actually stored values.
// FIXME: the nse value computed in this way might be wrong when there is
// any "loose_compressed" level.
rewriter.replaceOp(
op, genValMemSize(rewriter, op.getLoc(), adaptor.getTensor()));
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
op.getTensor().getType());
rewriter.replaceOp(op, desc.getValMemSize(rewriter, op.getLoc()));
return success();
}
};
Expand Down Expand Up @@ -1415,7 +1432,8 @@ struct SparseDisassembleOpConverter
LogicalResult
matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
op.getTensor().getType());
Location loc = op.getLoc();
SmallVector<Value> retMem;
SmallVector<Value> retLen;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -554,11 +554,6 @@ sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) {
.getResult();
}

// Convenience wrapper: unpack the lowered sparse tensor tuple into a
// descriptor and query the memory size of its stored-values buffer.
Value sparse_tensor::genValMemSize(OpBuilder &builder, Location loc,
                                   Value tensor) {
  auto desc = getDescriptorFromTensorTuple(tensor);
  return desc.getValMemSize(builder, loc);
}

Value sparse_tensor::createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc,
Value tensor, Dimension dim) {
auto enc = getSparseTensorEncoding(tensor.getType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,9 +270,6 @@ void storeAll(OpBuilder &builder, Location loc, Value mem, ValueRange vs,
TypedValue<BaseMemRefType> genToMemref(OpBuilder &builder, Location loc,
Value tensor);

/// Generates code to retrieve the values size for the sparse tensor.
Value genValMemSize(OpBuilder &builder, Location loc, Value tensor);

/// Generates code to retrieve the slice offset for the sparse tensor slice,
/// return a constant if the offset is statically known.
Value createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc, Value tensor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,18 +245,18 @@ inline Value genTuple(OpBuilder &builder, Location loc,
return genTuple(builder, loc, desc.getRankedTensorType(), desc.getFields());
}

/// Unpacks the tuple that a lowered sparse tensor value has been
/// materialized into and wraps its field values in an immutable
/// SparseTensorDescriptor. The ranked tensor type of the original sparse
/// tensor is supplied explicitly by the caller (rather than recovered from
/// the tuple's result type, as the previous single-argument overload did).
inline SparseTensorDescriptor
getDescriptorFromTensorTuple(Value tensor, RankedTensorType type) {
  auto tuple = getTuple(tensor);
  return SparseTensorDescriptor(SparseTensorType(type), tuple.getInputs());
}

/// Unpacks the tuple that a lowered sparse tensor value has been
/// materialized into and wraps its field values in a mutable
/// MutSparseTensorDescriptor. The tuple's field values are copied into
/// `fields`, which must outlive the returned descriptor (the descriptor
/// references, not owns, that storage). The ranked tensor type of the
/// original sparse tensor is supplied explicitly by the caller.
inline MutSparseTensorDescriptor
getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields,
                                RankedTensorType type) {
  auto tuple = getTuple(tensor);
  fields.assign(tuple.getInputs().begin(), tuple.getInputs().end());
  return MutSparseTensorDescriptor(SparseTensorType(type), fields);
}

} // namespace sparse_tensor
Expand Down
Loading