diff --git a/include/clang/Interpreter/CppInterOp.h b/include/clang/Interpreter/CppInterOp.h index c7c3d51ee..79a0df75f 100644 --- a/include/clang/Interpreter/CppInterOp.h +++ b/include/clang/Interpreter/CppInterOp.h @@ -676,6 +676,15 @@ namespace Cpp { const char* m_IntegralValue; TemplateArgInfo(TCppScope_t type, const char* integral_value = nullptr) : m_Type(type), m_IntegralValue(integral_value) {} + friend bool operator==(const TemplateArgInfo& lhs, + const TemplateArgInfo& rhs) { + return (lhs.m_Type == rhs.m_Type && + lhs.m_IntegralValue == rhs.m_IntegralValue); + } + friend bool operator!=(const TemplateArgInfo& lhs, + const TemplateArgInfo& rhs) { + return !(lhs == rhs); + } }; /// Builds a template instantiation for a given templated declaration. /// Offers a single interface for instantiation of class, function and diff --git a/lib/Interpreter/CppInterOp.cpp b/lib/Interpreter/CppInterOp.cpp index 51f26a2b3..431cd5971 100755 --- a/lib/Interpreter/CppInterOp.cpp +++ b/lib/Interpreter/CppInterOp.cpp @@ -14,6 +14,7 @@ #include "clang/AST/CXXInheritance.h" #include "clang/AST/Decl.h" #include "clang/AST/DeclCXX.h" +#include "clang/AST/DeclTemplate.h" #include "clang/AST/GlobalDecl.h" #include "clang/AST/Mangle.h" #include "clang/AST/QualTypeNames.h" @@ -42,6 +43,7 @@ #include #include #include +#include // Stream redirect. #ifdef _WIN32 @@ -1026,11 +1028,103 @@ namespace Cpp { funcs.push_back(Found); } + namespace { + inline void + collectUniqueTemplateArgs(const std::vector& templ_types, + std::vector& result) { + std::unique_copy(templ_types.begin(), templ_types.end(), + std::back_inserter(result)); + } + bool + IsTemplateFunctionGoodMatch(const FunctionTemplateDecl* FTD, + const std::vector& arg_types, + std::vector& templ_types) { + const FunctionDecl* F = FTD->getTemplatedDecl(); + clang::TemplateParameterList* tpl = FTD->getTemplateParameters(); + + if (arg_types.size() != F->getNumParams()) + return false; + + for (size_t i = 0; i < arg_types.size(); i++) { + QualType fn_arg_type = F->getParamDecl(i)->getType(); + QualType arg_type = QualType::getFromOpaquePtr(arg_types[i].m_Type); + + // dereference + if (fn_arg_type->isReferenceType()) + fn_arg_type = fn_arg_type.getNonReferenceType(); + if (arg_type->isReferenceType()) + arg_type = arg_type.getNonReferenceType(); + + fn_arg_type = fn_arg_type.getCanonicalType(); + arg_type = arg_type.getCanonicalType(); + + // matching parameter and argument types + // resolving parameter + const auto* fn_TST = + fn_arg_type->getAs(); + const TemplateDecl* fn_TD = nullptr; + if (fn_TST) + fn_TD = fn_TST->getTemplateName().getAsTemplateDecl(); + + // resolving argument + const auto* arg_RT = arg_type->getAs(); + ClassTemplateSpecializationDecl* arg_CTSD = nullptr; + if (arg_RT) + arg_CTSD = llvm::dyn_cast( + arg_RT->getDecl()); + + if ((!arg_CTSD || !fn_TD) && (arg_CTSD || fn_TD)) + return false; + + // check if types match + if (arg_CTSD) { + auto* arg_D = arg_CTSD->getSpecializedTemplate()->getCanonicalDecl(); + if (arg_D != fn_TD->getCanonicalDecl()) + return false; + if (templ_types.size() < tpl->size()) { + Cpp::GetClassTemplateInstantiationArgs(arg_CTSD, templ_types); + break; + } + } else if (templ_types.size() < tpl->size()) { + templ_types.push_back(arg_types[i]); + } + } + return true; + } + } // namespace + TCppFunction_t BestTemplateFunctionMatch(const std::vector& candidates, const std::vector& explicit_types, const std::vector& arg_types) { + /* + Try matching function with templated class as arguments first + Example: + + template + struct A { T value; }; + + template + void somefunc(A arg); // overload 1 + + template + void somefunc(T arg); // overload 2 + + somefunc(A()); // should call overload 1; resolve this first + somefunc(3); // should call overload 2 + */ + for (const auto& candidate : candidates) { + std::vector templ_types; + auto* TFD = static_cast(candidate); + if (IsTemplateFunctionGoodMatch(TFD, arg_types, templ_types)) { + TCppFunction_t instantiated = InstantiateTemplate( + candidate, templ_types.data(), templ_types.size()); + if (instantiated) + return instantiated; + } + } + for (const auto& candidate : candidates) { auto* TFD = (FunctionTemplateDecl*)candidate; clang::TemplateParameterList* tpl = TFD->getTemplateParameters(); @@ -1060,9 +1154,19 @@ namespace Cpp { if (instantiated) return instantiated; + std::vector unique_arg_types; + collectUniqueTemplateArgs(arg_types, unique_arg_types); + instantiated = InstantiateTemplate(candidate, unique_arg_types.data(), + unique_arg_types.size()); + if (instantiated) + return instantiated; + // Force the instantiation with template params in case of no args // maybe steer instantiation better with arg set returned from // TemplateProxy? + if (explicit_types.empty()) + continue; + instantiated = InstantiateTemplate(candidate, explicit_types.data(), explicit_types.size()); if (instantiated) diff --git a/unittests/CppInterOp/FunctionReflectionTest.cpp b/unittests/CppInterOp/FunctionReflectionTest.cpp index 1c576fbe7..f4407cd8b 100644 --- a/unittests/CppInterOp/FunctionReflectionTest.cpp +++ b/unittests/CppInterOp/FunctionReflectionTest.cpp @@ -653,6 +653,68 @@ TEST(FunctionReflectionTest, BestTemplateFunctionMatch) { "template<> long MyTemplatedMethodClass::get_size(float &)"); } +TEST(FunctionReflectionTest, BestTemplateFunctionMatch2) { + std::vector Decls; + std::string code = R"( + template + struct A { T value; }; + + A a; + + template + void somefunc(A arg) {} + + template + void somefunc(T arg) {} + + template + void somefunc(A arg1, A arg2) {} + + template + void somefunc(T arg1, T arg2) {} + )"; + + GetAllTopLevelDecls(code, Decls); + std::vector candidates; + + for (auto decl : Decls) + if (Cpp::IsTemplatedFunction(decl)) + candidates.push_back((Cpp::TCppFunction_t)decl); + + EXPECT_EQ(candidates.size(), 4); + + ASTContext& C = Interp->getCI()->getASTContext(); + + std::vector args1 = {C.IntTy.getAsOpaquePtr()}; + std::vector args2 = { + Cpp::GetVariableType(Cpp::GetNamed("a"))}; + std::vector args3 = {C.IntTy.getAsOpaquePtr(), + C.IntTy.getAsOpaquePtr()}; + std::vector args4 = { + Cpp::GetVariableType(Cpp::GetNamed("a")), + Cpp::GetVariableType(Cpp::GetNamed("a"))}; + + std::vector explicit_args; + + Cpp::TCppFunction_t func1 = + Cpp::BestTemplateFunctionMatch(candidates, explicit_args, args1); + Cpp::TCppFunction_t func2 = + Cpp::BestTemplateFunctionMatch(candidates, explicit_args, args2); + Cpp::TCppFunction_t func3 = + Cpp::BestTemplateFunctionMatch(candidates, explicit_args, args3); + Cpp::TCppFunction_t func4 = + Cpp::BestTemplateFunctionMatch(candidates, explicit_args, args4); + + EXPECT_EQ(Cpp::GetFunctionSignature(func1), + "template<> void somefunc(int arg)"); + EXPECT_EQ(Cpp::GetFunctionSignature(func2), + "template<> void somefunc(A arg)"); + EXPECT_EQ(Cpp::GetFunctionSignature(func3), + "template<> void somefunc(int arg1, int arg2)"); + EXPECT_EQ(Cpp::GetFunctionSignature(func4), + "template<> void somefunc(A arg1, A arg2)"); +} + TEST(FunctionReflectionTest, IsPublicMethod) { std::vector Decls, SubDecls; std::string code = R"(