Skip to content

Commit 720898c

Browse files
Address comments from @ftynse's review.
1 parent 1833cbb commit 720898c

File tree

1 file changed

+41
-26
lines changed

1 file changed

+41
-26
lines changed

mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp

Lines changed: 41 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -617,34 +617,20 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
617617
extraMappings, options);
618618
}
619619

620-
LogicalResult transform::detail::interpreterBaseInitializeImpl(
621-
MLIRContext *context, StringRef transformFileName,
622-
ArrayRef<std::string> transformLibraryPaths,
623-
std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
624-
std::shared_ptr<OwningOpRef<ModuleOp>> &transformLibraryModule,
625-
function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
626-
moduleBuilder) {
627-
auto unknownLoc = UnknownLoc::get(context);
628-
629-
// Parse module from file.
630-
OwningOpRef<ModuleOp> moduleFromFile;
631-
{
632-
auto loc = FileLineColLoc::get(context, transformFileName, 0, 0);
633-
if (failed(parseTransformModuleFromFile(context, transformFileName,
634-
moduleFromFile)))
635-
return emitError(loc) << "failed to parse transform module";
636-
if (moduleFromFile && failed(mlir::verify(*moduleFromFile)))
637-
return emitError(loc) << "failed to verify transform module";
638-
}
639-
640-
// Assemble list of library files.
641-
SmallVector<std::string> libraryFileNames;
642-
for (const std::string &path : transformLibraryPaths) {
643-
auto loc = FileLineColLoc::get(context, transformFileName, 0, 0);
620+
/// Expands the given list of `paths` to a list of `.mlir` files.
621+
///
622+
/// Each entry in `paths` may either be a regular file, in which case it ends up
623+
/// in the result list, or a directory, in which case all (regular) `.mlir`
624+
/// files in that directory are added. Any other file types lead to a failure.
625+
static LogicalResult
626+
expandPathsToMLIRFiles(ArrayRef<std::string> &paths, MLIRContext *const context,
627+
SmallVectorImpl<std::string> &fileNames) {
628+
for (const std::string &path : paths) {
629+
auto loc = FileLineColLoc::get(context, path, 0, 0);
644630

645631
if (llvm::sys::fs::is_regular_file(path)) {
646632
LLVM_DEBUG(DBGS() << "Adding '" << path << "' to list of files\n");
647-
libraryFileNames.push_back(path);
633+
fileNames.push_back(path);
648634
continue;
649635
}
650636

@@ -673,14 +659,43 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl(
673659
}
674660

675661
LLVM_DEBUG(DBGS() << " Adding '" << fileName << "' to list of files\n");
676-
libraryFileNames.push_back(fileName);
662+
fileNames.push_back(fileName);
677663
}
678664

679665
if (ec)
680666
return emitError(loc) << "error while opening files in '" << path
681667
<< "': " << ec.message();
682668
}
683669

670+
return success();
671+
}
672+
673+
LogicalResult transform::detail::interpreterBaseInitializeImpl(
674+
MLIRContext *context, StringRef transformFileName,
675+
ArrayRef<std::string> transformLibraryPaths,
676+
std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
677+
std::shared_ptr<OwningOpRef<ModuleOp>> &transformLibraryModule,
678+
function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
679+
moduleBuilder) {
680+
auto unknownLoc = UnknownLoc::get(context);
681+
682+
// Parse module from file.
683+
OwningOpRef<ModuleOp> moduleFromFile;
684+
{
685+
auto loc = FileLineColLoc::get(context, transformFileName, 0, 0);
686+
if (failed(parseTransformModuleFromFile(context, transformFileName,
687+
moduleFromFile)))
688+
return emitError(loc) << "failed to parse transform module";
689+
if (moduleFromFile && failed(mlir::verify(*moduleFromFile)))
690+
return emitError(loc) << "failed to verify transform module";
691+
}
692+
693+
// Assemble list of library files.
694+
SmallVector<std::string> libraryFileNames;
695+
if (failed(expandPathsToMLIRFiles(transformLibraryPaths, context,
696+
libraryFileNames)))
697+
return failure();
698+
684699
// Parse modules from library files.
685700
SmallVector<OwningOpRef<ModuleOp>> parsedLibraries;
686701
for (const std::string &libraryFileName : libraryFileNames) {

0 commit comments

Comments
 (0)