diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp index d5c65b23e3a21..475368f0f406a 100644 --- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp +++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp @@ -311,6 +311,9 @@ static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) { auto readOnlyName = StringAttr::get(&ctx, transform::TransformDialect::kArgReadOnlyAttrName); + // Collect symbols missing in the block. + SmallVector missingSymbols; + LLVM_DEBUG(DBGS() << "searching block for missing symbols:\n"); for (Operation &op : llvm::make_early_inc_range(block)) { LLVM_DEBUG(DBGS() << op << "\n"); auto symbol = dyn_cast(op); @@ -318,25 +321,33 @@ static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) { continue; if (symbol->getNumRegions() == 1 && !symbol->getRegion(0).empty()) continue; + LLVM_DEBUG(DBGS() << " -> symbol missing\n"); + missingSymbols.push_back(symbol); + } - LLVM_DEBUG(DBGS() << "looking for definition of symbol " - << symbol.getNameAttr() << ":"); - SymbolTable symbolTable(definitions); - Operation *externalSymbol = symbolTable.lookup(symbol.getNameAttr()); + // Resolve missing symbols until they are all resolved. + while (!missingSymbols.empty()) { + SymbolOpInterface symbol = missingSymbols.pop_back_val(); + LLVM_DEBUG(DBGS() << "looking for definition of symbol @" + << symbol.getNameAttr().getValue() << ": "); + SymbolTable definitionsSymbolTable(definitions); + Operation *externalSymbol = + definitionsSymbolTable.lookup(symbol.getNameAttr()); if (!externalSymbol || externalSymbol->getNumRegions() != 1 || externalSymbol->getRegion(0).empty()) { LLVM_DEBUG(llvm::dbgs() << "not found\n"); continue; } - auto symbolFunc = dyn_cast(op); + auto symbolFunc = dyn_cast(symbol.getOperation()); auto externalSymbolFunc = dyn_cast(externalSymbol); if (!symbolFunc || !externalSymbolFunc) { LLVM_DEBUG(llvm::dbgs() << "cannot compare types\n"); continue; } - LLVM_DEBUG(llvm::dbgs() << "found @" << externalSymbol << "\n"); + LLVM_DEBUG(llvm::dbgs() << "found " << externalSymbol << " from " + << externalSymbol->getLoc() << "\n"); if (symbolFunc.getFunctionType() != externalSymbolFunc.getFunctionType()) { return symbolFunc.emitError() << "external definition has a mismatching signature (" @@ -367,10 +378,53 @@ static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) { } } - OpBuilder builder(&op); - builder.setInsertionPoint(&op); - builder.clone(*externalSymbol); + OpBuilder builder(symbol); + builder.setInsertionPoint(symbol); + Operation *newSymbol = builder.clone(*externalSymbol); + builder.setInsertionPoint(newSymbol); symbol->erase(); + + LLVM_DEBUG(DBGS() << "scanning definition of @" + << externalSymbolFunc.getNameAttr().getValue() + << " for symbol usages\n"); + externalSymbolFunc.walk([&](CallOpInterface callOp) { + LLVM_DEBUG(DBGS() << " call op in:\n" << callOp << "\n"); + CallInterfaceCallable callable = callOp.getCallableForCallee(); + if (!isa(callable)) { + LLVM_DEBUG(DBGS() << " not a symbol usage\n"); + return WalkResult::advance(); + } + + StringRef callableSymbolName = + cast(callable).getLeafReference(); + LLVM_DEBUG(DBGS() << " looking for @" << callableSymbolName + << " in definitions: "); + + Operation *callableOp = definitionsSymbolTable.lookup(callableSymbolName); + if (!isa(callable)) { + LLVM_DEBUG(llvm::dbgs() << "not found\n"); + return WalkResult::advance(); + } + LLVM_DEBUG(llvm::dbgs() << "found " << callableOp << " from " + << callableOp->getLoc() << "\n"); + + if (!block.getParent() || !block.getParent()->getParentOp()) { + LLVM_DEBUG(DBGS() << "could not get parent op of provided block"); + return WalkResult::advance(); + } + + SymbolTable targetSymbolTable(block.getParent()->getParentOp()); + if (targetSymbolTable.lookup(callableSymbolName)) { + LLVM_DEBUG(DBGS() << " symbol @" << callableSymbolName + << " already present in target\n"); + return WalkResult::advance(); + } + + LLVM_DEBUG(DBGS() << " cloning op into target\n"); + builder.clone(*callableOp); + + return WalkResult::advance(); + }); } return success(); diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-transitive.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-transitive.mlir new file mode 100644 index 0000000000000..0e9fa7c59bc41 --- /dev/null +++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-transitive.mlir @@ -0,0 +1,27 @@ +// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \ +// RUN: --verify-diagnostics --split-input-file | FileCheck %s + +// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \ +// RUN: --verify-diagnostics --split-input-file | FileCheck %s + +// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \ +// RUN: --verify-diagnostics --split-input-file | FileCheck %s + +// The definition of the @bar named sequence is provided in another file. It +// will be included because of the pass option. That sequence uses another named +// sequence @foo, which should be made available here. Repeated application of +// the same pass, with or without the library option, should not be a problem. +// Note that the same diagnostic produced twice at the same location only +// needs to be matched once. + +// expected-remark @below {{message}} +module attributes {transform.with_named_sequence} { + // CHECK-DAG: transform.named_sequence @foo + // CHECK-DAG: transform.named_sequence @bar + transform.named_sequence private @bar(!transform.any_op {transform.readonly}) + + transform.sequence failures(propagate) { + ^bb0(%arg0: !transform.any_op): + include @bar failures(propagate) (%arg0) : (!transform.any_op) -> () + } +} diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir index 1149bda98ab85..9aa2d46d5abb9 100644 --- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir @@ -1,6 +1,11 @@ // RUN: mlir-opt %s module attributes {transform.with_named_sequence} { + transform.named_sequence @bar(%arg0: !transform.any_op) { + transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> () + transform.yield + } + transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) { transform.test_print_remark_at_operand %arg0, "message" : !transform.any_op transform.yield