Skip to content

Commit 0b7362c

Browse files
authored
[mlir][arith] Add result pretty printing for constant vscale values (#83565)
In scalable code it is very common to have constant multiples of vscale, e.g. `4 * vscale`. This updates `arith.muli` to pretty print the result name in cases like this, so `4 * vscale` would be `%c4_vscale`. This makes reading IR dumps of scalable code a little nicer.
1 parent be8bc3c commit 0b7362c

File tree

3 files changed

+44
-1
lines changed

3 files changed

+44
-1
lines changed

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,9 @@ def Arith_SubIOp : Arith_IntBinaryOpWithOverflowFlags<"subi"> {
343343
// MulIOp
344344
//===----------------------------------------------------------------------===//
345345

346-
def Arith_MulIOp : Arith_IntBinaryOpWithOverflowFlags<"muli", [Commutative]> {
346+
def Arith_MulIOp : Arith_IntBinaryOpWithOverflowFlags<"muli",
347+
[Commutative, DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]
348+
> {
347349
let summary = [{
348350
Integer multiplication operation.
349351
}];

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,33 @@ OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
423423
[](const APInt &a, const APInt &b) { return a * b; });
424424
}
425425

426+
void arith::MulIOp::getAsmResultNames(
427+
function_ref<void(Value, StringRef)> setNameFn) {
428+
if (!isa<IndexType>(getType()))
429+
return;
430+
431+
// Match vector.vscale by name to avoid depending on the vector dialect (which
432+
// is a circular dependency).
433+
auto isVscale = [](Operation *op) {
434+
return op && op->getName().getStringRef() == "vector.vscale";
435+
};
436+
437+
IntegerAttr baseValue;
438+
auto isVscaleExpr = [&](Value a, Value b) {
439+
return matchPattern(a, m_Constant(&baseValue)) &&
440+
isVscale(b.getDefiningOp());
441+
};
442+
443+
if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
444+
return;
445+
446+
// Name `base * vscale` or `vscale * base` as `c<base_value>_vscale`.
447+
SmallString<32> specialNameBuffer;
448+
llvm::raw_svector_ostream specialName(specialNameBuffer);
449+
specialName << 'c' << baseValue.getInt() << "_vscale";
450+
setNameFn(getResult(), specialName.str());
451+
}
452+
426453
void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
427454
MLIRContext *context) {
428455
patterns.add<MulIMulIConstant>(context);
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// RUN: mlir-opt %s | FileCheck %s
2+
3+
// Note: This test is checking value names (so deliberately is not using a regex match).
4+
5+
func.func @test_vscale_constant_names() {
6+
%vscale = vector.vscale
7+
%c8 = arith.constant 8 : index
8+
// CHECK: %c8_vscale = arith.muli
9+
%0 = arith.muli %vscale, %c8 : index
10+
%c10 = arith.constant 10 : index
11+
// CHECK: %c10_vscale = arith.muli
12+
%1 = arith.muli %c10, %vscale : index
13+
return
14+
}

0 commit comments

Comments
 (0)