diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 14404d837ff74..18ee36efab9d8 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -2658,26 +2658,23 @@ SmallVector transform::TileUsingForOp::getMixedSizes() { // `array` prefix to be consistent in the IR with `parseDynamicIndexList`. ParseResult parseOptionalInterchange(OpAsmParser &parser, OperationState &result) { - if (succeeded(parser.parseOptionalLBrace())) { - if (failed(parser.parseKeyword("interchange"))) - return parser.emitError(parser.getNameLoc()) << "expect `interchange`"; - if (failed(parser.parseEqual())) - return parser.emitError(parser.getNameLoc()) << "expect `=`"; - result.addAttribute("interchange", - DenseI64ArrayAttr::parse(parser, Type{})); - if (failed(parser.parseRBrace())) - return parser.emitError(parser.getNameLoc()) << "expect `}`"; - } + if (failed(parser.parseOptionalKeyword("interchange"))) + return success(); + if (failed(parser.parseEqual())) + return failure(); + result.addAttribute( + transform::TileUsingForOp::getInterchangeAttrName(result.name), + DenseI64ArrayAttr::parse(parser, Type{})); return success(); } void printOptionalInterchange(OpAsmPrinter &p, ArrayRef interchangeVals) { if (!interchangeVals.empty()) { - p << " {interchange = ["; + p << " interchange = ["; llvm::interleaveComma(interchangeVals, p, [&](int64_t integer) { p << integer; }); - p << "]}"; + p << "]"; } } @@ -2693,6 +2690,7 @@ ParseResult transform::TileUsingForOp::parse(OpAsmParser &parser, if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc) || parseDynamicIndexList(parser, dynamicSizes, staticSizes, scalableVals) || parseOptionalInterchange(parser, result) || + parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(functionalType)) return ParseResult::failure(); @@ -2727,6 +2725,11 @@ void TileUsingForOp::print(OpAsmPrinter &p) { /*valueTypes=*/{}, getScalableSizesAttr(), OpAsmParser::Delimiter::Square); printOptionalInterchange(p, getInterchange()); + p.printOptionalAttrDict( + (*this)->getAttrs(), + /*elidedAttrs=*/{getInterchangeAttrName(getOperation()->getName()), + getScalableSizesAttrName(getOperation()->getName()), + getStaticSizesAttrName(getOperation()->getName())}); p << " : "; p.printFunctionalType(getOperands().getTypes(), getResults().getTypes()); } diff --git a/mlir/test/Dialect/Linalg/transform-ops.mlir b/mlir/test/Dialect/Linalg/transform-ops.mlir index e9f044be5b4ed..6b276e69a595d 100644 --- a/mlir/test/Dialect/Linalg/transform-ops.mlir +++ b/mlir/test/Dialect/Linalg/transform-ops.mlir @@ -6,6 +6,16 @@ transform.sequence failures(propagate) { %0, %1:2 = transform.structured.tile_using_for %arg0 [2, 0, 3] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) } +// check that the Attributes of `tile_using_for` are preserved through printing +// and parsing with and without use of the optional `interchange` Attribute. +transform.sequence failures(propagate) { +^bb1(%arg0: !transform.any_op): + // CHECK %{{.*}}, %{{.*}}:2 = transform.structured.tile %arg0 [2, 0, 3] interchange = [2, 1] {test_attr1 = 1 : i64, test_attr2} + %0, %1:2 = transform.structured.tile_using_for %arg0 [2, 0, 3] interchange = [2, 1] {test_attr1 = 1 : i64, test_attr2}: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + // CHECK %{{.*}}, %{{.*}}:2 = transform.structured.tile %arg0 [4, 5, 3] {test_attr3 = 1 : i64, test_attr4} + %2, %3:2 = transform.structured.tile_using_for %0 [0, 5, 3] {test_attr3 = 1 : i64, test_attr4}: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) +} + transform.sequence failures(propagate) { ^bb1(%arg0: !transform.any_op): %0:2 = transform.structured.split %arg0 after 42 { dimension = 0 } : !transform.any_op diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir index 77ce4d0b211f0..5a9b490c07ff2 100644 --- a/mlir/test/Dialect/Linalg/transform-patterns.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -170,7 +170,7 @@ func.func @matvec_perm(%A: memref>, module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.matvec"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1, %loops:2 = transform.structured.tile_using_for %0 [5, 6] {interchange = [1, 0]} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + %1, %loops:2 = transform.structured.tile_using_for %0 [5, 6] interchange = [1, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) transform.yield } } @@ -199,8 +199,8 @@ func.func @matmul_perm(%A: memref>, module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1, %loops:3 = transform.structured.tile_using_for %0 [2000, 3000, 4000] {interchange = [1, 2, 0]} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) - %2, %loops_2:3 = transform.structured.tile_using_for %1 [200, 300, 400] {interchange = [1, 0, 2]} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + %1, %loops:3 = transform.structured.tile_using_for %0 [2000, 3000, 4000] interchange = [1, 2, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + %2, %loops_2:3 = transform.structured.tile_using_for %1 [200, 300, 400] interchange = [1, 0, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) %3, %loops_3:3 = transform.structured.tile_using_for %2 [20, 30, 40] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) transform.yield }