diff --git a/include/clang/Interpreter/CppInterOp.h b/include/clang/Interpreter/CppInterOp.h index c7c3d51ee..c0cb234aa 100644 --- a/include/clang/Interpreter/CppInterOp.h +++ b/include/clang/Interpreter/CppInterOp.h @@ -210,6 +210,9 @@ namespace Cpp { /// Checks if the scope is a class or not. CPPINTEROP_API bool IsClass(TCppScope_t scope); + /// Checks if the scope is a function. + CPPINTEROP_API bool IsFunction(TCppScope_t scope); + /// Checks if the type is a function pointer. CPPINTEROP_API bool IsFunctionPointerType(TCppType_t type); @@ -706,17 +709,16 @@ namespace Cpp { CPPINTEROP_API TCppFunction_t InstantiateTemplateFunctionFromString(const char* function_template); - /// Finds best template match based on explicit template parameters and - /// argument types + /// Finds best overload match based on explicit template parameters (if any) + /// and argument types. /// - ///\param[in] candidates - Vector of suitable candidates that come under the - /// parent scope and have the same name (obtained using - /// GetClassTemplatedMethods) + ///\param[in] candidates - vector of overloads that come under the + /// parent scope and have the same name ///\param[in] explicit_types - set of expicitly instantiated template types ///\param[in] arg_types - set of argument types ///\returns Instantiated function pointer CPPINTEROP_API TCppFunction_t - BestTemplateFunctionMatch(const std::vector& candidates, + BestOverloadFunctionMatch(const std::vector& candidates, const std::vector& explicit_types, const std::vector& arg_types); diff --git a/lib/Interpreter/CppInterOp.cpp b/lib/Interpreter/CppInterOp.cpp index 51f26a2b3..e4ef404a7 100755 --- a/lib/Interpreter/CppInterOp.cpp +++ b/lib/Interpreter/CppInterOp.cpp @@ -13,16 +13,29 @@ #include "clang/AST/CXXInheritance.h" #include "clang/AST/Decl.h" +#include "clang/AST/DeclAccessPair.h" +#include "clang/AST/DeclBase.h" #include "clang/AST/DeclCXX.h" +#include "clang/AST/DeclarationName.h" +#include "clang/AST/Expr.h" +#include "clang/AST/ExprCXX.h" #include "clang/AST/GlobalDecl.h" #include "clang/AST/Mangle.h" +#include "clang/AST/NestedNameSpecifier.h" #include "clang/AST/QualTypeNames.h" #include "clang/AST/RecordLayout.h" +#include "clang/AST/Stmt.h" +#include "clang/AST/Type.h" #include "clang/Basic/DiagnosticSema.h" #include "clang/Basic/Linkage.h" +#include "clang/Basic/OperatorKinds.h" +#include "clang/Basic/SourceLocation.h" +#include "clang/Basic/Specifiers.h" #include "clang/Basic/Version.h" #include "clang/Frontend/CompilerInstance.h" #include "clang/Sema/Lookup.h" +#include "clang/Sema/Overload.h" +#include "clang/Sema/Ownership.h" #include "clang/Sema/Sema.h" #if CLANG_VERSION_MAJOR >= 19 #include "clang/Sema/Redeclaration.h" @@ -210,6 +223,11 @@ namespace Cpp { return isa(D); } + bool IsFunction(TCppScope_t scope) { + Decl* D = static_cast(scope); + return isa(D); + } + bool IsFunctionPointerType(TCppType_t type) { QualType QT = QualType::getFromOpaquePtr(type); return QT->isFunctionPointerType(); @@ -877,8 +895,19 @@ namespace Cpp { TCppType_t GetFunctionReturnType(TCppFunction_t func) { auto *D = (clang::Decl *) func; - if (auto* FD = llvm::dyn_cast_or_null(D)) - return FD->getReturnType().getAsOpaquePtr(); + if (auto* FD = llvm::dyn_cast_or_null(D)) { + QualType Type = FD->getReturnType(); + if (Type->isUndeducedAutoType() && IsTemplatedFunction(FD) && + !FD->isDefined()) { +#ifdef CPPINTEROP_USE_CLING + cling::Interpreter::PushTransactionRAII RAII(&getInterp()); +#endif + getSema().InstantiateFunctionDefinition(SourceLocation(), FD, true, + true); + Type = FD->getReturnType(); + } + return Type.getAsOpaquePtr(); + } if (auto* FD = llvm::dyn_cast_or_null(D)) return (FD->getTemplatedDecl())->getReturnType().getAsOpaquePtr(); @@ -1026,62 +1055,89 @@ namespace Cpp { funcs.push_back(Found); } + // Adapted from inner workings of Sema::BuildCallExpr TCppFunction_t - BestTemplateFunctionMatch(const std::vector& candidates, + BestOverloadFunctionMatch(const std::vector& candidates, const std::vector& explicit_types, const std::vector& arg_types) { + auto& S = getSema(); + auto& C = S.getASTContext(); - for (const auto& candidate : candidates) { - auto* TFD = (FunctionTemplateDecl*)candidate; - clang::TemplateParameterList* tpl = TFD->getTemplateParameters(); +#ifdef CPPINTEROP_USE_CLING + cling::Interpreter::PushTransactionRAII RAII(&getInterp()); +#endif - // template parameter size does not match - if (tpl->size() < explicit_types.size()) - continue; + // The overload resolution interfaces in Sema require a list of expressions. + // However, unlike handwritten C++, we do not always have a expression. + // Here we synthesize a placeholder expression to be able to use + // Sema::AddOverloadCandidate. Made up expressions are fine because the + // interface uses the list size and the expression types. + struct WrapperExpr : public OpaqueValueExpr { + WrapperExpr() : OpaqueValueExpr(clang::Stmt::EmptyShell()) {} + }; + auto* Exprs = new WrapperExpr[arg_types.size()]; + llvm::SmallVector Args; + Args.reserve(arg_types.size()); + size_t idx = 0; + for (auto i : arg_types) { + QualType Type = QualType::getFromOpaquePtr(i.m_Type); + ExprValueKind ExprKind = ExprValueKind::VK_PRValue; + if (Type->isReferenceType()) + ExprKind = ExprValueKind::VK_LValue; + + new (&Exprs[idx]) OpaqueValueExpr(SourceLocation::getFromRawEncoding(1), + Type.getNonReferenceType(), ExprKind); + Args.push_back(&Exprs[idx]); + ++idx; + } - // right now uninstantiated functions give template typenames instead of - // actual types. We make this match solely based on count + // Create a list of template arguments. + llvm::SmallVector TemplateArgs; + TemplateArgs.reserve(explicit_types.size()); + for (auto explicit_type : explicit_types) { + QualType ArgTy = QualType::getFromOpaquePtr(explicit_type.m_Type); + if (explicit_type.m_IntegralValue) { + // We have a non-type template parameter. Create an integral value from + // the string representation. + auto Res = llvm::APSInt(explicit_type.m_IntegralValue); + Res = Res.extOrTrunc(C.getIntWidth(ArgTy)); + TemplateArgs.push_back(TemplateArgument(C, Res, ArgTy)); + } else { + TemplateArgs.push_back(ArgTy); + } + } - const FunctionDecl* func = TFD->getTemplatedDecl(); + TemplateArgumentListInfo ExplicitTemplateArgs{}; + for (auto TA : TemplateArgs) + ExplicitTemplateArgs.addArgument( + S.getTrivialTemplateArgumentLoc(TA, QualType(), SourceLocation())); + + OverloadCandidateSet Overloads( + SourceLocation(), OverloadCandidateSet::CandidateSetKind::CSK_Normal); + + for (void* i : candidates) { + Decl* D = static_cast(i); + if (auto* FD = dyn_cast(D)) { + S.AddOverloadCandidate(FD, DeclAccessPair::make(FD, FD->getAccess()), + Args, Overloads); + } else if (auto* FTD = dyn_cast(D)) { + // AddTemplateOverloadCandidate is causing a memory leak + // It is a known bug at clang + // call stack: AddTemplateOverloadCandidate -> MakeDeductionFailureInfo + // source: + // https://github.com/llvm/llvm-project/blob/release/19.x/clang/lib/Sema/SemaOverload.cpp#L731-L756 + S.AddTemplateOverloadCandidate( + FTD, DeclAccessPair::make(FTD, FTD->getAccess()), + &ExplicitTemplateArgs, Args, Overloads); + } + } -#ifdef CPPINTEROP_USE_CLING - if (func->getNumParams() > arg_types.size()) - continue; -#else // CLANG_REPL - if (func->getMinRequiredArguments() > arg_types.size()) - continue; -#endif + OverloadCandidateSet::iterator Best; + Overloads.BestViableFunction(S, SourceLocation(), Best); - // TODO(aaronj0) : first score based on the type similarity before forcing - // instantiation. - - TCppFunction_t instantiated = - InstantiateTemplate(candidate, arg_types.data(), arg_types.size()); - if (instantiated) - return instantiated; - - // Force the instantiation with template params in case of no args - // maybe steer instantiation better with arg set returned from - // TemplateProxy? - instantiated = InstantiateTemplate(candidate, explicit_types.data(), - explicit_types.size()); - if (instantiated) - return instantiated; - - // join explicit and arg_types - std::vector total_arg_set; - total_arg_set.reserve(explicit_types.size() + arg_types.size()); - total_arg_set.insert(total_arg_set.end(), explicit_types.begin(), - explicit_types.end()); - total_arg_set.insert(total_arg_set.end(), arg_types.begin(), - arg_types.end()); - - instantiated = InstantiateTemplate(candidate, total_arg_set.data(), - total_arg_set.size()); - if (instantiated) - return instantiated; - } - return nullptr; + FunctionDecl* Result = Best != Overloads.end() ? Best->Function : nullptr; + delete[] Exprs; + return Result; } // Gets the AccessSpecifier of the function and checks if it is equal to diff --git a/unittests/CppInterOp/FunctionReflectionTest.cpp b/unittests/CppInterOp/FunctionReflectionTest.cpp index 1c576fbe7..feb028cea 100644 --- a/unittests/CppInterOp/FunctionReflectionTest.cpp +++ b/unittests/CppInterOp/FunctionReflectionTest.cpp @@ -313,6 +313,11 @@ TEST(FunctionReflectionTest, GetFunctionReturnType) { return sizeof(A) + i; } }; + + template struct RTTest_TemplatedList {}; + template auto rttest_make_tlist(T ... args) { + return RTTest_TemplatedList{}; + } )"; GetAllTopLevelDecls(code, Decls, true); @@ -348,6 +353,16 @@ TEST(FunctionReflectionTest, GetFunctionReturnType) { EXPECT_EQ( Cpp::GetTypeAsString(Cpp::GetFunctionReturnType(TemplateSubDecls[3])), "long"); + + ASTContext& C = Interp->getCI()->getASTContext(); + std::vector args = {C.IntTy.getAsOpaquePtr(), + C.DoubleTy.getAsOpaquePtr()}; + std::vector explicit_args; + std::vector candidates = {Decls[14]}; + EXPECT_EQ( + Cpp::GetTypeAsString(Cpp::GetFunctionReturnType( + Cpp::BestOverloadFunctionMatch(candidates, explicit_args, args))), + "RTTest_TemplatedList"); } TEST(FunctionReflectionTest, GetFunctionNumArgs) { @@ -590,7 +605,7 @@ TEST(FunctionReflectionTest, InstantiateTemplateMethod) { EXPECT_TRUE(TA1.getAsType()->isIntegerType()); } -TEST(FunctionReflectionTest, BestTemplateFunctionMatch) { +TEST(FunctionReflectionTest, BestOverloadFunctionMatch1) { std::vector Decls; std::string code = R"( class MyTemplatedMethodClass { @@ -598,7 +613,8 @@ TEST(FunctionReflectionTest, BestTemplateFunctionMatch) { template long get_size(A&); template long get_size(); template long get_size(A a, B b); - template long add_size(float a); + template long get_size(float a); + template long get_size(T a); }; template @@ -612,7 +628,7 @@ TEST(FunctionReflectionTest, BestTemplateFunctionMatch) { } template - long MyTemplatedMethodClass::add_size(float a) { + long MyTemplatedMethodClass::get_size(float a) { return sizeof(A) + long(a); } @@ -620,6 +636,11 @@ TEST(FunctionReflectionTest, BestTemplateFunctionMatch) { long MyTemplatedMethodClass::get_size(A a, B b) { return sizeof(A) + sizeof(B); } + + template + long MyTemplatedMethodClass::get_size(T a) { + return N + sizeof(T) + a; + } )"; GetAllTopLevelDecls(code, Decls); @@ -631,17 +652,26 @@ TEST(FunctionReflectionTest, BestTemplateFunctionMatch) { ASTContext& C = Interp->getCI()->getASTContext(); std::vector args0; - std::vector args1 = {C.IntTy.getAsOpaquePtr()}; + std::vector args1 = { + C.getLValueReferenceType(C.IntTy).getAsOpaquePtr()}; std::vector args2 = {C.CharTy.getAsOpaquePtr(), C.FloatTy.getAsOpaquePtr()}; std::vector args3 = {C.FloatTy.getAsOpaquePtr()}; std::vector explicit_args0; std::vector explicit_args1 = {C.IntTy.getAsOpaquePtr()}; - - Cpp::TCppFunction_t func1 = Cpp::BestTemplateFunctionMatch(candidates, explicit_args0, args1); - Cpp::TCppFunction_t func2 = Cpp::BestTemplateFunctionMatch(candidates, explicit_args1, args0); - Cpp::TCppFunction_t func3 = Cpp::BestTemplateFunctionMatch(candidates, explicit_args0, args2); - Cpp::TCppFunction_t func4 = Cpp::BestTemplateFunctionMatch(candidates, explicit_args1, args3); + std::vector explicit_args2 = { + {C.IntTy.getAsOpaquePtr(), "1"}, C.IntTy.getAsOpaquePtr()}; + + Cpp::TCppFunction_t func1 = + Cpp::BestOverloadFunctionMatch(candidates, explicit_args0, args1); + Cpp::TCppFunction_t func2 = + Cpp::BestOverloadFunctionMatch(candidates, explicit_args1, args0); + Cpp::TCppFunction_t func3 = + Cpp::BestOverloadFunctionMatch(candidates, explicit_args0, args2); + Cpp::TCppFunction_t func4 = + Cpp::BestOverloadFunctionMatch(candidates, explicit_args1, args3); + Cpp::TCppFunction_t func5 = + Cpp::BestOverloadFunctionMatch(candidates, explicit_args2, args3); EXPECT_EQ(Cpp::GetFunctionSignature(func1), "template<> long MyTemplatedMethodClass::get_size(int &)"); @@ -650,7 +680,228 @@ TEST(FunctionReflectionTest, BestTemplateFunctionMatch) { EXPECT_EQ(Cpp::GetFunctionSignature(func3), "template<> long MyTemplatedMethodClass::get_size(char a, float b)"); EXPECT_EQ(Cpp::GetFunctionSignature(func4), - "template<> long MyTemplatedMethodClass::get_size(float &)"); + "template<> long MyTemplatedMethodClass::get_size(float a)"); + EXPECT_EQ(Cpp::GetFunctionSignature(func5), + "template<> long MyTemplatedMethodClass::get_size<1, int>(int a)"); +} + +TEST(FunctionReflectionTest, BestOverloadFunctionMatch2) { + std::vector Decls; + std::string code = R"( + template + struct A { T value; }; + + A a; + + template + void somefunc(A arg) {} + + template + void somefunc(T arg) {} + + template + void somefunc(A arg1, A arg2) {} + + template + void somefunc(T arg1, T arg2) {} + + void somefunc(int arg1, double arg2) {} + )"; + + GetAllTopLevelDecls(code, Decls); + std::vector candidates; + + for (auto decl : Decls) + if (Cpp::IsFunction(decl) || Cpp::IsTemplatedFunction(decl)) + candidates.push_back((Cpp::TCppFunction_t)decl); + + EXPECT_EQ(candidates.size(), 5); + + ASTContext& C = Interp->getCI()->getASTContext(); + + std::vector args1 = {C.IntTy.getAsOpaquePtr()}; + std::vector args2 = { + Cpp::GetVariableType(Cpp::GetNamed("a"))}; + std::vector args3 = {C.IntTy.getAsOpaquePtr(), + C.IntTy.getAsOpaquePtr()}; + std::vector args4 = { + Cpp::GetVariableType(Cpp::GetNamed("a")), + Cpp::GetVariableType(Cpp::GetNamed("a"))}; + std::vector args5 = {C.IntTy.getAsOpaquePtr(), + C.DoubleTy.getAsOpaquePtr()}; + + std::vector explicit_args; + + Cpp::TCppFunction_t func1 = + Cpp::BestOverloadFunctionMatch(candidates, explicit_args, args1); + Cpp::TCppFunction_t func2 = + Cpp::BestOverloadFunctionMatch(candidates, explicit_args, args2); + Cpp::TCppFunction_t func3 = + Cpp::BestOverloadFunctionMatch(candidates, explicit_args, args3); + Cpp::TCppFunction_t func4 = + Cpp::BestOverloadFunctionMatch(candidates, explicit_args, args4); + Cpp::TCppFunction_t func5 = + Cpp::BestOverloadFunctionMatch(candidates, explicit_args, args5); + + EXPECT_EQ(Cpp::GetFunctionSignature(func1), + "template<> void somefunc(int arg)"); + EXPECT_EQ(Cpp::GetFunctionSignature(func2), + "template<> void somefunc(A arg)"); + EXPECT_EQ(Cpp::GetFunctionSignature(func3), + "template<> void somefunc(int arg1, int arg2)"); + EXPECT_EQ(Cpp::GetFunctionSignature(func4), + "template<> void somefunc(A arg1, A arg2)"); + EXPECT_EQ(Cpp::GetFunctionSignature(func5), + "void somefunc(int arg1, double arg2)"); +} + +TEST(FunctionReflectionTest, BestOverloadFunctionMatch3) { + std::vector Decls; + std::string code = R"( + template + struct A { + T value; + + template + A operator-(A rhs) { + return A{value - rhs.value}; + } + }; + + A a; + + template + A operator+(A lhs, A rhs) { + return A{lhs.value + rhs.value}; + } + + template + A operator+(A lhs, int rhs) { + return A{lhs.value + rhs}; + } + )"; + + GetAllTopLevelDecls(code, Decls); + std::vector candidates; + + for (auto decl : Decls) + if (Cpp::IsTemplatedFunction(decl)) + candidates.push_back((Cpp::TCppFunction_t)decl); + + EXPECT_EQ(candidates.size(), 2); + + ASTContext& C = Interp->getCI()->getASTContext(); + + std::vector args1 = { + Cpp::GetVariableType(Cpp::GetNamed("a")), + Cpp::GetVariableType(Cpp::GetNamed("a"))}; + std::vector args2 = { + Cpp::GetVariableType(Cpp::GetNamed("a")), C.IntTy.getAsOpaquePtr()}; + std::vector args3 = { + Cpp::GetVariableType(Cpp::GetNamed("a")), C.DoubleTy.getAsOpaquePtr()}; + + std::vector explicit_args; + + Cpp::TCppFunction_t func1 = + Cpp::BestOverloadFunctionMatch(candidates, explicit_args, args1); + Cpp::TCppFunction_t func2 = + Cpp::BestOverloadFunctionMatch(candidates, explicit_args, args2); + Cpp::TCppFunction_t func3 = + Cpp::BestOverloadFunctionMatch(candidates, explicit_args, args3); + + candidates.clear(); + Cpp::GetOperator( + Cpp::GetScopeFromType(Cpp::GetVariableType(Cpp::GetNamed("a"))), + Cpp::Operator::OP_Minus, candidates); + + EXPECT_EQ(candidates.size(), 1); + + std::vector args4 = { + Cpp::GetVariableType(Cpp::GetNamed("a"))}; + + Cpp::TCppFunction_t func4 = + Cpp::BestOverloadFunctionMatch(candidates, explicit_args, args4); + + EXPECT_EQ(Cpp::GetFunctionSignature(func1), + "template<> A operator+(A lhs, A rhs)"); + EXPECT_EQ(Cpp::GetFunctionSignature(func2), + "template<> A operator+(A lhs, int rhs)"); + EXPECT_EQ(Cpp::GetFunctionSignature(func3), + "template<> A operator+(A lhs, int rhs)"); + + EXPECT_EQ(Cpp::GetFunctionSignature(func4), + "template<> A A::operator-(A rhs)"); +} + +TEST(FunctionReflectionTest, BestOverloadFunctionMatch4) { + std::vector Decls, SubDecls; + std::string code = R"( + template + struct A { T value; }; + + class B { + public: + void fn() {} + template + void fn(T x) {} + template + void fn(A x) {} + template + void fn(A x, A y) {} + }; + + A a; + A b; + )"; + + GetAllTopLevelDecls(code, Decls); + GetAllSubDecls(Decls[1], SubDecls); + std::vector candidates; + for (auto i : SubDecls) { + if ((Cpp::IsFunction(i) || Cpp::IsTemplatedFunction(i)) && + Cpp::GetName(i) == "fn") + candidates.push_back(i); + } + + EXPECT_EQ(candidates.size(), 4); + + ASTContext& C = Interp->getCI()->getASTContext(); + + std::vector args1 = {}; + std::vector args2 = {C.IntTy.getAsOpaquePtr()}; + std::vector args3 = { + Cpp::GetVariableType(Cpp::GetNamed("a"))}; + std::vector args4 = { + Cpp::GetVariableType(Cpp::GetNamed("a")), + Cpp::GetVariableType(Cpp::GetNamed("b"))}; + std::vector args5 = { + Cpp::GetVariableType(Cpp::GetNamed("a")), + Cpp::GetVariableType(Cpp::GetNamed("a"))}; + + std::vector explicit_args1; + std::vector explicit_args2 = {C.IntTy.getAsOpaquePtr(), + C.IntTy.getAsOpaquePtr()}; + + Cpp::TCppFunction_t func1 = + Cpp::BestOverloadFunctionMatch(candidates, explicit_args1, args1); + Cpp::TCppFunction_t func2 = + Cpp::BestOverloadFunctionMatch(candidates, explicit_args1, args2); + Cpp::TCppFunction_t func3 = + Cpp::BestOverloadFunctionMatch(candidates, explicit_args1, args3); + Cpp::TCppFunction_t func4 = + Cpp::BestOverloadFunctionMatch(candidates, explicit_args1, args4); + Cpp::TCppFunction_t func5 = + Cpp::BestOverloadFunctionMatch(candidates, explicit_args2, args5); + + EXPECT_EQ(Cpp::GetFunctionSignature(func1), "void B::fn()"); + EXPECT_EQ(Cpp::GetFunctionSignature(func2), + "template<> void B::fn(int x)"); + EXPECT_EQ(Cpp::GetFunctionSignature(func3), + "template<> void B::fn(A x)"); + EXPECT_EQ(Cpp::GetFunctionSignature(func4), + "template<> void B::fn(A x, A y)"); + EXPECT_EQ(Cpp::GetFunctionSignature(func5), + "template<> void B::fn(A x, A y)"); } TEST(FunctionReflectionTest, IsPublicMethod) {