@@ -617,34 +617,20 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
617
617
extraMappings, options);
618
618
}
619
619
620
- LogicalResult transform::detail::interpreterBaseInitializeImpl (
621
- MLIRContext *context, StringRef transformFileName,
622
- ArrayRef<std::string> transformLibraryPaths,
623
- std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
624
- std::shared_ptr<OwningOpRef<ModuleOp>> &transformLibraryModule,
625
- function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
626
- moduleBuilder) {
627
- auto unknownLoc = UnknownLoc::get (context);
628
-
629
- // Parse module from file.
630
- OwningOpRef<ModuleOp> moduleFromFile;
631
- {
632
- auto loc = FileLineColLoc::get (context, transformFileName, 0 , 0 );
633
- if (failed (parseTransformModuleFromFile (context, transformFileName,
634
- moduleFromFile)))
635
- return emitError (loc) << " failed to parse transform module" ;
636
- if (moduleFromFile && failed (mlir::verify (*moduleFromFile)))
637
- return emitError (loc) << " failed to verify transform module" ;
638
- }
639
-
640
- // Assemble list of library files.
641
- SmallVector<std::string> libraryFileNames;
642
- for (const std::string &path : transformLibraryPaths) {
643
- auto loc = FileLineColLoc::get (context, transformFileName, 0 , 0 );
620
+ // / Expands the given list of `paths` to a list of `.mlir` files.
621
+ // /
622
+ // / Each entry in `paths` may either be a regular file, in which case it ends up
623
+ // / in the result list, or a directory, in which case all (regular) `.mlir`
624
+ // / files in that directory are added. Any other file types lead to a failure.
625
+ static LogicalResult
626
+ expandPathsToMLIRFiles (ArrayRef<std::string> &paths, MLIRContext *const context,
627
+ SmallVectorImpl<std::string> &fileNames) {
628
+ for (const std::string &path : paths) {
629
+ auto loc = FileLineColLoc::get (context, path, 0 , 0 );
644
630
645
631
if (llvm::sys::fs::is_regular_file (path)) {
646
632
LLVM_DEBUG (DBGS () << " Adding '" << path << " ' to list of files\n " );
647
- libraryFileNames .push_back (path);
633
+ fileNames .push_back (path);
648
634
continue ;
649
635
}
650
636
@@ -673,14 +659,43 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl(
673
659
}
674
660
675
661
LLVM_DEBUG (DBGS () << " Adding '" << fileName << " ' to list of files\n " );
676
- libraryFileNames .push_back (fileName);
662
+ fileNames .push_back (fileName);
677
663
}
678
664
679
665
if (ec)
680
666
return emitError (loc) << " error while opening files in '" << path
681
667
<< " ': " << ec.message ();
682
668
}
683
669
670
+ return success ();
671
+ }
672
+
673
+ LogicalResult transform::detail::interpreterBaseInitializeImpl (
674
+ MLIRContext *context, StringRef transformFileName,
675
+ ArrayRef<std::string> transformLibraryPaths,
676
+ std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
677
+ std::shared_ptr<OwningOpRef<ModuleOp>> &transformLibraryModule,
678
+ function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
679
+ moduleBuilder) {
680
+ auto unknownLoc = UnknownLoc::get (context);
681
+
682
+ // Parse module from file.
683
+ OwningOpRef<ModuleOp> moduleFromFile;
684
+ {
685
+ auto loc = FileLineColLoc::get (context, transformFileName, 0 , 0 );
686
+ if (failed (parseTransformModuleFromFile (context, transformFileName,
687
+ moduleFromFile)))
688
+ return emitError (loc) << " failed to parse transform module" ;
689
+ if (moduleFromFile && failed (mlir::verify (*moduleFromFile)))
690
+ return emitError (loc) << " failed to verify transform module" ;
691
+ }
692
+
693
+ // Assemble list of library files.
694
+ SmallVector<std::string> libraryFileNames;
695
+ if (failed (expandPathsToMLIRFiles (transformLibraryPaths, context,
696
+ libraryFileNames)))
697
+ return failure ();
698
+
684
699
// Parse modules from library files.
685
700
SmallVector<OwningOpRef<ModuleOp>> parsedLibraries;
686
701
for (const std::string &libraryFileName : libraryFileNames) {
0 commit comments