From e1899e35495edf054c8a7c9e06ab658bbe722dcf Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Sun, 25 Aug 2024 10:54:25 -0700 Subject: [PATCH] [MLIR][DLTI] Enable types as keys in DLTI-query utils Enable support for query functions - including transform.dlti.query - to take types as keys. As the data layout specific attributes already supported types as keys, this change enables querying such attributes in the expected way. --- mlir/include/mlir/Dialect/DLTI/DLTI.h | 2 +- .../DLTI/TransformOps/DLTITransformOps.td | 9 +- mlir/lib/Dialect/DLTI/DLTI.cpp | 32 ++++- .../DLTI/TransformOps/DLTITransformOps.cpp | 11 +- mlir/test/Dialect/DLTI/invalid.mlir | 8 ++ mlir/test/Dialect/DLTI/query.mlir | 120 ++++++++++++++++++ mlir/test/Dialect/DLTI/valid.mlir | 15 +++ 7 files changed, 187 insertions(+), 10 deletions(-) diff --git a/mlir/include/mlir/Dialect/DLTI/DLTI.h b/mlir/include/mlir/Dialect/DLTI/DLTI.h index a97eb523cb063..f268fea340a6f 100644 --- a/mlir/include/mlir/Dialect/DLTI/DLTI.h +++ b/mlir/include/mlir/Dialect/DLTI/DLTI.h @@ -26,7 +26,7 @@ namespace mlir { namespace dlti { /// Perform a DLTI-query at `op`, recursively querying each key of `keys` on /// query interface-implementing attrs, starting from attr obtained from `op`. -FailureOr query(Operation *op, ArrayRef keys, +FailureOr query(Operation *op, ArrayRef keys, bool emitError = false); } // namespace dlti } // namespace mlir diff --git a/mlir/include/mlir/Dialect/DLTI/TransformOps/DLTITransformOps.td b/mlir/include/mlir/Dialect/DLTI/TransformOps/DLTITransformOps.td index 1b1bebfaab4e3..f25bb383912d4 100644 --- a/mlir/include/mlir/Dialect/DLTI/TransformOps/DLTITransformOps.td +++ b/mlir/include/mlir/Dialect/DLTI/TransformOps/DLTITransformOps.td @@ -26,9 +26,10 @@ def QueryOp : Op dlti::query(Operation *op, ArrayRef keys, - bool emitError) { +FailureOr +dlti::query(Operation *op, ArrayRef keys, bool emitError) { + if (keys.empty()) { + if (emitError) { + auto diag = op->emitError() << "target op of failed DLTI query"; + diag.attachNote(op->getLoc()) << "no keys provided to attempt query with"; + } + return failure(); + } + auto [queryable, queryOp] = getClosestQueryable(op); Operation *reportOp = (queryOp ? queryOp : op); @@ -438,6 +446,15 @@ FailureOr dlti::query(Operation *op, ArrayRef keys, return failure(); } + auto keyToStr = [](DataLayoutEntryKey key) -> std::string { + std::string buf; + llvm::TypeSwitch(key) + .Case( // The only two kinds of key we know of. + [&](auto key) { llvm::raw_string_ostream(buf) << key; }) + .Default([](auto) { llvm_unreachable("unexpected entry key kind"); }); + return buf; + }; + Attribute currentAttr = queryable; for (auto &&[idx, key] : llvm::enumerate(keys)) { if (auto map = llvm::dyn_cast(currentAttr)) { @@ -446,17 +463,24 @@ FailureOr dlti::query(Operation *op, ArrayRef keys, if (emitError) { auto diag = op->emitError() << "target op of failed DLTI query"; diag.attachNote(reportOp->getLoc()) - << "key " << key << " has no DLTI-mapping per attr: " << map; + << "key " << keyToStr(key) + << " has no DLTI-mapping per attr: " << map; } return failure(); } currentAttr = *maybeAttr; } else { if (emitError) { + std::string commaSeparatedKeys; + llvm::interleave( + keys.take_front(idx), // All prior keys. + [&](auto key) { commaSeparatedKeys += keyToStr(key); }, + [&]() { commaSeparatedKeys += ","; }); + auto diag = op->emitError() << "target op of failed DLTI query"; diag.attachNote(reportOp->getLoc()) << "got non-DLTI-queryable attribute upon looking up keys [" - << keys.take_front(idx) << "] at op"; + << commaSeparatedKeys << "] at op"; } return failure(); } diff --git a/mlir/lib/Dialect/DLTI/TransformOps/DLTITransformOps.cpp b/mlir/lib/Dialect/DLTI/TransformOps/DLTITransformOps.cpp index 90aef82bddff0..02c41b4fe8113 100644 --- a/mlir/lib/Dialect/DLTI/TransformOps/DLTITransformOps.cpp +++ b/mlir/lib/Dialect/DLTI/TransformOps/DLTITransformOps.cpp @@ -33,7 +33,16 @@ void transform::QueryOp::getEffects( DiagnosedSilenceableFailure transform::QueryOp::applyToOne( transform::TransformRewriter &rewriter, Operation *target, transform::ApplyToEachResultList &results, TransformState &state) { - auto keys = SmallVector(getKeys().getAsRange()); + SmallVector keys; + for (Attribute key : getKeys()) { + if (auto strKey = dyn_cast(key)) + keys.push_back(strKey); + else if (auto typeKey = dyn_cast(key)) + keys.push_back(typeKey.getValue()); + else + return emitDefiniteFailure("'transform.dlti.query' keys of wrong type: " + "only StringAttr and TypeAttr are allowed"); + } FailureOr result = dlti::query(target, keys, /*emitError=*/true); diff --git a/mlir/test/Dialect/DLTI/invalid.mlir b/mlir/test/Dialect/DLTI/invalid.mlir index 05f919fa25671..4b04f0195ef82 100644 --- a/mlir/test/Dialect/DLTI/invalid.mlir +++ b/mlir/test/Dialect/DLTI/invalid.mlir @@ -33,6 +33,14 @@ // ----- +// expected-error@below {{repeated layout entry key: 'i32'}} +"test.unknown_op"() { test.unknown_attr = #dlti.map< + #dlti.dl_entry, + #dlti.dl_entry +>} : () -> () + +// ----- + // expected-error@below {{repeated layout entry key: 'i32'}} "test.unknown_op"() { test.unknown_attr = #dlti.dl_spec< #dlti.dl_entry, diff --git a/mlir/test/Dialect/DLTI/query.mlir b/mlir/test/Dialect/DLTI/query.mlir index 10e91afd2ca7e..a793c1a6e8e6a 100644 --- a/mlir/test/Dialect/DLTI/query.mlir +++ b/mlir/test/Dialect/DLTI/query.mlir @@ -17,6 +17,60 @@ module attributes {transform.with_named_sequence} { // ----- +// expected-remark @below {{i32 present in set : unit}} +module attributes { test.dlti = #dlti.map<#dlti.dl_entry>} { + func.func private @f() +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg: !transform.any_op) { + %funcs = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op + %module = transform.get_parent_op %funcs : (!transform.any_op) -> !transform.any_op + %param = transform.dlti.query [i32] at %module : (!transform.any_op) -> !transform.any_param + transform.debug.emit_param_as_remark %param, "i32 present in set :" at %module : !transform.any_param, !transform.any_op + transform.yield + } +} + +// ----- + +// expected-remark @below {{associated attr 32 : i32}} +module attributes { test.dlti = #dlti.map<#dlti.dl_entry>>>} { + func.func private @f() +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg: !transform.any_op) { + %funcs = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op + %module = transform.get_parent_op %funcs : (!transform.any_op) -> !transform.any_op + %param = transform.dlti.query [i32,"width_in_bits"] at %module : (!transform.any_op) -> !transform.any_param + transform.debug.emit_param_as_remark %param, "associated attr" at %module : !transform.any_param, !transform.any_op + transform.yield + } +} + +// ----- + +// expected-remark @below {{width in bits of i32 = 32 : i64}} +// expected-remark @below {{width in bits of f64 = 64 : i64}} +module attributes { test.dlti = #dlti.map<#dlti.dl_entry<"width_in_bits", #dlti.map<#dlti.dl_entry, #dlti.dl_entry>>>} { + func.func private @f() +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg: !transform.any_op) { + %funcs = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op + %module = transform.get_parent_op %funcs : (!transform.any_op) -> !transform.any_op + %i32bits = transform.dlti.query ["width_in_bits",i32] at %module : (!transform.any_op) -> !transform.any_param + %f64bits = transform.dlti.query ["width_in_bits",f64] at %module : (!transform.any_op) -> !transform.any_param + transform.debug.emit_param_as_remark %i32bits, "width in bits of i32 =" at %module : !transform.any_param, !transform.any_op + transform.debug.emit_param_as_remark %f64bits, "width in bits of f64 =" at %module : !transform.any_param, !transform.any_op + transform.yield + } +} + +// ----- + // expected-remark @below {{associated attr 42 : i32}} module attributes { test.dlti = #dlti.dl_spec<#dlti.dl_entry<"test.id", 42 : i32>>} { func.func private @f() @@ -336,6 +390,23 @@ module attributes {transform.with_named_sequence} { // ----- +// expected-note @below {{got non-DLTI-queryable attribute upon looking up keys [i32]}} +module attributes { test.dlti = #dlti.dl_spec<#dlti.dl_entry>} { + // expected-error @below {{target op of failed DLTI query}} + func.func private @f() +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg: !transform.any_op) { + %func = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op + // expected-error @below {{'transform.dlti.query' op failed to apply}} + %param = transform.dlti.query [i32,"width_in_bits"] at %func : (!transform.any_op) -> !transform.any_param + transform.yield + } +} + +// ----- + module { // expected-error @below {{target op of failed DLTI query}} // expected-note @below {{no DLTI-queryable attrs on target op or any of its ancestors}} @@ -353,6 +424,55 @@ module attributes {transform.with_named_sequence} { // ----- +// expected-note @below {{key i64 has no DLTI-mapping per attr: #dlti.map<#dlti.dl_entry>}} +module attributes { test.dlti = #dlti.map<#dlti.dl_entry<"width_in_bits", #dlti.map<#dlti.dl_entry>>>} { + // expected-error @below {{target op of failed DLTI query}} + func.func private @f() +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg: !transform.any_op) { + %func = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op + // expected-error @below {{'transform.dlti.query' op failed to apply}} + %param = transform.dlti.query ["width_in_bits",i64] at %func : (!transform.any_op) -> !transform.any_param + transform.yield + } +} + +// ----- + +module attributes { test.dlti = #dlti.dl_spec<#dlti.dl_entry<"test.id", 42 : i32>>} { + func.func private @f() +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg: !transform.any_op) { + %funcs = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op + // expected-error @below {{'transform.dlti.query' keys of wrong type: only StringAttr and TypeAttr are allowed}} + %param = transform.dlti.query [1] at %funcs : (!transform.any_op) -> !transform.param + transform.yield + } +} + +// ----- + +module attributes { test.dlti = #dlti.map<#dlti.dl_entry<"test.id", 42 : i32>>} { + // expected-error @below {{target op of failed DLTI query}} + // expected-note @below {{no keys provided to attempt query with}} + func.func private @f() +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg: !transform.any_op) { + %func = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op + // expected-error @below {{'transform.dlti.query' op failed to apply}} + %param = transform.dlti.query [] at %func : (!transform.any_op) -> !transform.any_param + transform.yield + } +} + +// ----- + module attributes { test.dlti = #dlti.dl_spec<#dlti.dl_entry<"test.id", 42 : i32>>} { func.func private @f() } diff --git a/mlir/test/Dialect/DLTI/valid.mlir b/mlir/test/Dialect/DLTI/valid.mlir index 4133eac5424ce..31c925e5cb5be 100644 --- a/mlir/test/Dialect/DLTI/valid.mlir +++ b/mlir/test/Dialect/DLTI/valid.mlir @@ -206,3 +206,18 @@ module attributes { "GPU": #dlti.target_device_spec< #dlti.dl_entry<"L1_cache_size_in_bytes", "128">> >} {} + + +// ----- + +// CHECK: "test.op_with_dlti_map"() ({ +// CHECK: }) {dlti.map = #dlti.map<#dlti.dl_entry<"dlti.unknown_id", 42 : i64>>} +"test.op_with_dlti_map"() ({ +}) { dlti.map = #dlti.map<#dlti.dl_entry<"dlti.unknown_id", 42>> } : () -> () + +// ----- + +// CHECK: "test.op_with_dlti_map"() ({ +// CHECK: }) {dlti.map = #dlti.map<#dlti.dl_entry>} +"test.op_with_dlti_map"() ({ +}) { dlti.map = #dlti.map<#dlti.dl_entry> } : () -> ()