diff --git a/llvm/include/llvm/Demangle/Demangle.h b/llvm/include/llvm/Demangle/Demangle.h index fe129603c0785..132e5088b5514 100644 --- a/llvm/include/llvm/Demangle/Demangle.h +++ b/llvm/include/llvm/Demangle/Demangle.h @@ -10,6 +10,7 @@ #define LLVM_DEMANGLE_DEMANGLE_H #include +#include #include #include @@ -54,6 +55,9 @@ enum MSDemangleFlags { char *microsoftDemangle(std::string_view mangled_name, size_t *n_read, int *status, MSDemangleFlags Flags = MSDF_None); +std::optional +getArm64ECInsertionPointInMangledName(std::string_view MangledName); + // Demangles a Rust v0 mangled symbol. char *rustDemangle(std::string_view MangledName); diff --git a/llvm/include/llvm/Demangle/MicrosoftDemangle.h b/llvm/include/llvm/Demangle/MicrosoftDemangle.h index 6891185a28e57..276efa7603690 100644 --- a/llvm/include/llvm/Demangle/MicrosoftDemangle.h +++ b/llvm/include/llvm/Demangle/MicrosoftDemangle.h @@ -9,6 +9,7 @@ #ifndef LLVM_DEMANGLE_MICROSOFTDEMANGLE_H #define LLVM_DEMANGLE_MICROSOFTDEMANGLE_H +#include "llvm/Demangle/Demangle.h" #include "llvm/Demangle/MicrosoftDemangleNodes.h" #include @@ -141,6 +142,9 @@ enum class FunctionIdentifierCodeGroup { Basic, Under, DoubleUnder }; // It has a set of functions to parse mangled symbols into Type instances. // It also has a set of functions to convert Type instances to strings. class Demangler { + friend std::optional + llvm::getArm64ECInsertionPointInMangledName(std::string_view MangledName); + public: Demangler() = default; virtual ~Demangler() = default; diff --git a/llvm/include/llvm/IR/Mangler.h b/llvm/include/llvm/IR/Mangler.h index 3c3f0c6dce80f..6c8ebf5f072f2 100644 --- a/llvm/include/llvm/IR/Mangler.h +++ b/llvm/include/llvm/IR/Mangler.h @@ -64,7 +64,7 @@ std::optional getArm64ECDemangledFunctionName(StringRef Name); /// Check if an ARM64EC function name is mangled. bool inline isArm64ECMangledFunctionName(StringRef Name) { return Name[0] == '#' || - (Name[0] == '?' && Name.find("$$h") != StringRef::npos); + (Name[0] == '?' && Name.find("@$$h") != StringRef::npos); } } // End llvm namespace diff --git a/llvm/lib/Demangle/MicrosoftDemangle.cpp b/llvm/lib/Demangle/MicrosoftDemangle.cpp index aa65f3be29da7..6be8b0fe73996 100644 --- a/llvm/lib/Demangle/MicrosoftDemangle.cpp +++ b/llvm/lib/Demangle/MicrosoftDemangle.cpp @@ -24,6 +24,7 @@ #include #include #include +#include #include #include @@ -2428,6 +2429,24 @@ void Demangler::dumpBackReferences() { std::printf("\n"); } +std::optional +llvm::getArm64ECInsertionPointInMangledName(std::string_view MangledName) { + std::string_view ProcessedName{MangledName}; + + // We only support this for MSVC-style C++ symbols. + if (!consumeFront(ProcessedName, '?')) + return std::nullopt; + + // The insertion point is just after the name of the symbol, so parse that to + // remove it from the processed name. + Demangler D; + D.demangleFullyQualifiedSymbolName(ProcessedName); + if (D.Error) + return std::nullopt; + + return MangledName.length() - ProcessedName.length(); +} + char *llvm::microsoftDemangle(std::string_view MangledName, size_t *NMangled, int *Status, MSDemangleFlags Flags) { Demangler D; diff --git a/llvm/lib/IR/Mangler.cpp b/llvm/lib/IR/Mangler.cpp index 15a4debf191a5..3b9c00cf993f3 100644 --- a/llvm/lib/IR/Mangler.cpp +++ b/llvm/lib/IR/Mangler.cpp @@ -14,6 +14,7 @@ #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/Twine.h" +#include "llvm/Demangle/Demangle.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" @@ -299,21 +300,17 @@ std::optional llvm::getArm64ECMangledFunctionName(StringRef Name) { return std::optional(("#" + Name).str()); } - // Insert the ARM64EC "$$h" tag after the mangled function name. + // If the name contains $$h, then it is already mangled. if (Name.contains("$$h")) return std::nullopt; - size_t InsertIdx = Name.find("@@"); - size_t ThreeAtSignsIdx = Name.find("@@@"); - if (InsertIdx != std::string::npos && InsertIdx != ThreeAtSignsIdx) { - InsertIdx += 2; - } else { - InsertIdx = Name.find("@"); - if (InsertIdx != std::string::npos) - InsertIdx++; - } + + // Ask the demangler where we should insert "$$h". + auto InsertIdx = getArm64ECInsertionPointInMangledName(Name); + if (!InsertIdx) + return std::nullopt; return std::optional( - (Name.substr(0, InsertIdx) + "$$h" + Name.substr(InsertIdx)).str()); + (Name.substr(0, *InsertIdx) + "$$h" + Name.substr(*InsertIdx)).str()); } std::optional diff --git a/llvm/unittests/IR/ManglerTest.cpp b/llvm/unittests/IR/ManglerTest.cpp index 5ac784b7e89ac..a2b4e81690310 100644 --- a/llvm/unittests/IR/ManglerTest.cpp +++ b/llvm/unittests/IR/ManglerTest.cpp @@ -172,4 +172,81 @@ TEST(ManglerTest, GOFF) { "L#foo"); } +TEST(ManglerTest, Arm64EC) { + constexpr std::string_view Arm64ECNames[] = { + // Basic C name. + "#Foo", + + // Basic C++ name. + "?foo@@$$hYAHXZ", + + // Regression test: https://github.com/llvm/llvm-project/issues/115231 + "?GetValue@?$Wrapper@UA@@@@$$hQEBAHXZ", + + // Symbols from: + // ``` + // namespace A::B::C::D { + // struct Base { + // virtual int f() { return 0; } + // }; + // } + // struct Derived : public A::B::C::D::Base { + // virtual int f() override { return 1; } + // }; + // A::B::C::D::Base* MakeObj() { return new Derived(); } + // ``` + // void * __cdecl operator new(unsigned __int64) + "??2@$$hYAPEAX_K@Z", + // public: virtual int __cdecl A::B::C::D::Base::f(void) + "?f@Base@D@C@B@A@@$$hUEAAHXZ", + // public: __cdecl A::B::C::D::Base::Base(void) + "??0Base@D@C@B@A@@$$hQEAA@XZ", + // public: virtual int __cdecl Derived::f(void) + "?f@Derived@@$$hUEAAHXZ", + // public: __cdecl Derived::Derived(void) + "??0Derived@@$$hQEAA@XZ", + // struct A::B::C::D::Base * __cdecl MakeObj(void) + "?MakeObj@@$$hYAPEAUBase@D@C@B@A@@XZ", + + // Symbols from: + // ``` + // template struct WW { struct Z{}; }; + // template struct Wrapper { + // int GetValue(typename WW::Z) const; + // }; + // struct A { }; + // template int Wrapper::GetValue(typename WW::Z) const + // { return 3; } + // template class Wrapper; + // ``` + // public: int __cdecl Wrapper::GetValue(struct WW::Z)const + "?GetValue@?$Wrapper@UA@@@@$$hQEBAHUZ@?$WW@UA@@@@@Z", + }; + + for (const auto &Arm64ECName : Arm64ECNames) { + // Check that this is a mangled name. + EXPECT_TRUE(isArm64ECMangledFunctionName(Arm64ECName)) + << "Test case: " << Arm64ECName; + // Refuse to mangle it again. + EXPECT_FALSE(getArm64ECMangledFunctionName(Arm64ECName).has_value()) + << "Test case: " << Arm64ECName; + + // Demangle. + auto Arm64Name = getArm64ECDemangledFunctionName(Arm64ECName); + EXPECT_TRUE(Arm64Name.has_value()) << "Test case: " << Arm64ECName; + // Check that it is not mangled. + EXPECT_FALSE(isArm64ECMangledFunctionName(Arm64Name.value())) + << "Test case: " << Arm64ECName; + // Refuse to demangle it again. + EXPECT_FALSE(getArm64ECDemangledFunctionName(Arm64Name.value()).has_value()) + << "Test case: " << Arm64ECName; + + // Round-trip. + auto RoundTripArm64ECName = + getArm64ECMangledFunctionName(Arm64Name.value()); + EXPECT_EQ(RoundTripArm64ECName, Arm64ECName); + } +} + } // end anonymous namespace