|
17 | 17 | #include "clang/CIR/Dialect/IR/CIRTypes.h" |
18 | 18 | #include "clang/CIR/Interfaces/CIRLoopOpInterface.h" |
19 | 19 | #include "clang/CIR/MissingFeatures.h" |
| 20 | +#include "llvm/ADT/SetOperations.h" |
| 21 | +#include "llvm/ADT/SmallSet.h" |
20 | 22 | #include "llvm/ADT/TypeSwitch.h" |
21 | 23 | #include "llvm/Support/ErrorHandling.h" |
22 | 24 | #include "llvm/Support/LogicalResult.h" |
23 | 25 | #include <numeric> |
24 | 26 | #include <optional> |
25 | | -#include <set> |
26 | 27 |
|
27 | 28 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
28 | 29 | #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
@@ -2886,38 +2887,41 @@ LogicalResult cir::FuncOp::verify() { |
2886 | 2887 | << "' must have empty body"; |
2887 | 2888 | } |
2888 | 2889 |
|
2889 | | - std::set<llvm::StringRef> labels; |
2890 | | - std::set<llvm::StringRef> gotos; |
2891 | | - std::set<llvm::StringRef> blockAddresses; |
| 2890 | + llvm::SmallSet<llvm::StringRef, 16> labels; |
| 2891 | + llvm::SmallSet<llvm::StringRef, 16> gotos; |
| 2892 | + llvm::SmallSet<llvm::StringRef, 16> blockAddresses; |
2892 | 2893 | bool invalidBlockAddress = false; |
2893 | 2894 | getOperation()->walk([&](mlir::Operation *op) { |
2894 | 2895 | if (auto lab = dyn_cast<cir::LabelOp>(op)) { |
2895 | | - labels.emplace(lab.getLabel()); |
| 2896 | + labels.insert(lab.getLabel()); |
2896 | 2897 | } else if (auto goTo = dyn_cast<cir::GotoOp>(op)) { |
2897 | | - gotos.emplace(goTo.getLabel()); |
| 2898 | + gotos.insert(goTo.getLabel()); |
2898 | 2899 | } else if (auto blkAdd = dyn_cast<cir::BlockAddressOp>(op)) { |
2899 | | - if (blkAdd.getFunc() != getSymName()) |
| 2900 | + if (blkAdd.getFunc() != getSymName()) { |
| 2901 | + // Stop the walk early, no need to continue |
2900 | 2902 | invalidBlockAddress = true; |
2901 | | - blockAddresses.emplace(blkAdd.getLabel()); |
| 2903 | + return mlir::WalkResult::interrupt(); |
| 2904 | + } |
| 2905 | + blockAddresses.insert(blkAdd.getLabel()); |
2902 | 2906 | } |
| 2907 | + return mlir::WalkResult::advance(); |
2903 | 2908 | }); |
2904 | 2909 |
|
2905 | 2910 | if (invalidBlockAddress) |
2906 | 2911 | return emitOpError() << "blockaddress references a different function"; |
2907 | 2912 |
|
2908 | | - { |
2909 | | - std::vector<llvm::StringRef> mismatched; |
2910 | | - std::set_difference(gotos.begin(), gotos.end(), labels.begin(), |
2911 | | - labels.end(), std::back_inserter(mismatched)); |
| 2913 | + llvm::SmallSet<llvm::StringRef, 16> mismatched; |
| 2914 | + if (!labels.empty() || !gotos.empty()) { |
| 2915 | + mismatched = llvm::set_difference(gotos, labels); |
2912 | 2916 |
|
2913 | 2917 | if (!mismatched.empty()) |
2914 | 2918 | return emitOpError() << "goto/label mismatch"; |
2915 | 2919 | } |
2916 | | - { |
2917 | | - std::vector<llvm::StringRef> mismatched; |
2918 | | - std::set_difference(blockAddresses.begin(), blockAddresses.end(), |
2919 | | - labels.begin(), labels.end(), |
2920 | | - std::back_inserter(mismatched)); |
| 2920 | + |
| 2921 | + mismatched.clear(); |
| 2922 | + |
| 2923 | + if (!labels.empty() || !blockAddresses.empty()) { |
| 2924 | + mismatched = llvm::set_difference(blockAddresses, labels); |
2921 | 2925 |
|
2922 | 2926 | if (!mismatched.empty()) |
2923 | 2927 | return emitOpError() |
|
0 commit comments