diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index f1c3d717f1fa9..c8f0806e27a62 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -1819,7 +1819,7 @@ def TileUsingForOp : Op:$dynamic_sizes, DefaultValuedOptionalAttr:$static_sizes, - DefaultValuedOptionalAttr:$interchange, + DefaultValuedOptionalAttr:$interchange, DefaultValuedOptionalAttr:$scalable_sizes); let results = (outs TransformHandleTypeInterface:$tiled_linalg_op, Variadic:$loops); diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index de4965f937162..73de3f22d896f 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -2477,7 +2477,7 @@ void transform::TileUsingForOp::build( /*target=*/target, /*dynamic_sizes=*/dynamicTileSizes, /*static_sizes=*/staticTileSizesAttr, - /*interchange=*/builder.getDenseI64ArrayAttr(interchange), + /*interchange=*/builder.getI64ArrayAttr(interchange), /*scalable_sizes=*/expandedScalableSizes); } @@ -2611,7 +2611,8 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter, }); } - tilingOptions.setInterchange(getInterchange()); + tilingOptions.setInterchange( + extractFromIntegerArrayAttr(getInterchange())); FailureOr maybeTilingResult = tileUsingSCFForOp(rewriter, tilingInterface, tilingOptions); if (failed(maybeTilingResult)) @@ -2648,33 +2649,6 @@ SmallVector transform::TileUsingForOp::getMixedSizes() { return results; } -// We want to parse `DenseI64ArrayAttr` using the short form without the -// `array` prefix to be consistent in the IR with `parseDynamicIndexList`. -ParseResult parseOptionalInterchange(OpAsmParser &parser, - OperationState &result) { - if (succeeded(parser.parseOptionalLBrace())) { - if (failed(parser.parseKeyword("interchange"))) - return parser.emitError(parser.getNameLoc()) << "expect `interchange`"; - if (failed(parser.parseEqual())) - return parser.emitError(parser.getNameLoc()) << "expect `=`"; - result.addAttribute("interchange", - DenseI64ArrayAttr::parse(parser, Type{})); - if (failed(parser.parseRBrace())) - return parser.emitError(parser.getNameLoc()) << "expect `}`"; - } - return success(); -} - -void printOptionalInterchange(OpAsmPrinter &p, - ArrayRef interchangeVals) { - if (!interchangeVals.empty()) { - p << " {interchange = ["; - llvm::interleaveComma(interchangeVals, p, - [&](int64_t integer) { p << integer; }); - p << "]}"; - } -} - ParseResult transform::TileUsingForOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::UnresolvedOperand target; @@ -2686,7 +2660,7 @@ ParseResult transform::TileUsingForOp::parse(OpAsmParser &parser, if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc) || parseDynamicIndexList(parser, dynamicSizes, staticSizes, scalableVals) || - parseOptionalInterchange(parser, result) || + parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(functionalType)) return ParseResult::failure(); @@ -2720,7 +2694,10 @@ void TileUsingForOp::print(OpAsmPrinter &p) { printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes(), /*valueTypes=*/{}, getScalableSizesAttr(), OpAsmParser::Delimiter::Square); - printOptionalInterchange(p, getInterchange()); + p.printOptionalAttrDict( + (*this)->getAttrs(), + /*elidedAttrs=*/{getScalableSizesAttrName(getOperation()->getName()), + getStaticSizesAttrName(getOperation()->getName())}); p << " : "; p.printFunctionalType(getOperands().getTypes(), getResults().getTypes()); } diff --git a/mlir/test/Dialect/Linalg/transform-ops.mlir b/mlir/test/Dialect/Linalg/transform-ops.mlir index e9f044be5b4ed..4d7c514dcca62 100644 --- a/mlir/test/Dialect/Linalg/transform-ops.mlir +++ b/mlir/test/Dialect/Linalg/transform-ops.mlir @@ -6,6 +6,14 @@ transform.sequence failures(propagate) { %0, %1:2 = transform.structured.tile_using_for %arg0 [2, 0, 3] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) } +// check that the Attributes of `tile_using_for` are preserved through printing +// and parsing. +transform.sequence failures(propagate) { +^bb1(%arg0: !transform.any_op): + // CHECK %{{.*}}, %{{.*}}:2 = transform.structured.tile %arg0 [2, 0, 3] {interchange = [2, 1], test_attr1 = 1 : i64, test_attr2} + %0, %1:2 = transform.structured.tile_using_for %arg0 [2, 0, 3] {test_attr1 = 1 : i64, interchange = [2, 1], test_attr2}: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) +} + transform.sequence failures(propagate) { ^bb1(%arg0: !transform.any_op): %0:2 = transform.structured.split %arg0 after 42 { dimension = 0 } : !transform.any_op