2121#include " llvm/ADT/StringMap.h"
2222#include " llvm/ADT/StringRef.h"
2323#include " llvm/ADT/Twine.h"
24+ #include " llvm/ADT/TypeSwitch.h"
2425#include " llvm/Support/ErrorHandling.h"
2526#include " llvm/Support/Path.h"
2627
@@ -56,16 +57,29 @@ template <typename TargetOp> class StdRecognizer {
5657 return dyn_cast<FuncOp>(global);
5758 }
5859
59- static bool isStdVector (const clang::CXXRecordDecl *RD) {
60- if (!RD || !RD->getDeclContext ()->isStdNamespace ())
61- return false ;
60+ static std::optional<StringRef>
61+ getRecordName (const clang::CXXRecordDecl *rd) {
62+ if (!rd || !rd->getDeclContext ()->isStdNamespace ())
63+ return std::nullopt ;
6264
63- if (RD->getDeclName ().isIdentifier ()) {
64- StringRef Name = RD->getName ();
65- return Name == " vector" ;
66- }
65+ if (rd->getDeclName ().isIdentifier ())
66+ return rd->getName ();
6767
68- return false ;
68+ return std::nullopt ;
69+ }
70+
71+ static std::optional<std::string>
72+ resolveSpecialMember (mlir::Attribute specialMember) {
73+ return TypeSwitch<Attribute, std::optional<std::string>>(specialMember)
74+ .Case <CXXCtorAttr, CXXDtorAttr>(
75+ [](auto attr) -> std::optional<std::string> {
76+ if (!attr.getRecordDecl ())
77+ return std::nullopt ;
78+ if (auto recordName = getRecordName (*attr.getRecordDecl ()))
79+ return recordName->str () + " _" + attr.getMnemonic ().str ();
80+ return std::nullopt ;
81+ })
82+ .Default ([](Attribute) { return std::nullopt ; });
6983 }
7084
7185 static bool raise (mlir::ModuleOp theModule, CallOp call,
@@ -74,21 +88,12 @@ template <typename TargetOp> class StdRecognizer {
7488 if (call.getNumOperands () != numArgs)
7589 return false ;
7690
77- llvm::StringRef name = *call.getCallee ();
78- auto calleeFunc = getCalleeFromSymbol (theModule, name);
79-
8091 llvm::StringRef stdFuncName = TargetOp::getFunctionName ();
92+ auto calleeFunc = getCalleeFromSymbol (theModule, *call.getCallee ());
8193
8294 if (auto specialMember = calleeFunc.getCxxSpecialMemberAttr ()) {
83- auto matches =
84- (stdFuncName == " vector_ctor" && isa<CXXCtorAttr>(specialMember)) ||
85- (stdFuncName == " vector_dtor" && isa<CXXDtorAttr>(specialMember));
86- if (!matches)
87- return false ;
88-
89- auto recordDeclAttr = call.getAstRecordAttr ();
90- if (!recordDeclAttr ||
91- !isStdVector (cast<clang::CXXRecordDecl>(recordDeclAttr.getRawDecl ())))
95+ auto resolved = resolveSpecialMember (specialMember);
96+ if (!resolved || *resolved != stdFuncName.str ())
9297 return false ;
9398 } else {
9499 auto callExprAttr = call.getAstAttr ();
@@ -97,12 +102,12 @@ template <typename TargetOp> class StdRecognizer {
97102
98103 if (!checkArguments (call.getArgOperands ()))
99104 return false ;
100-
101- if (remark)
102- mlir::emitRemark (call.getLoc ())
103- << " found call to std::" << stdFuncName << " ()" ;
104105 }
105106
107+ if (remark)
108+ mlir::emitRemark (call.getLoc ())
109+ << " found call to std::" << stdFuncName << " ()" ;
110+
106111 CIRBaseBuilderTy builder (context);
107112 builder.setInsertionPointAfter (call.getOperation ());
108113 TargetOp op = buildCall (builder, call, std::make_index_sequence<numArgs>());
0 commit comments