Skip to content

Commit 3a54022

Browse files
[Clang][Sema] Fix comparison of constraint expressions
This diff switches the approach to comparison of constraint expressions to the new one based on template args substitution. It continues the effort to fix our handling of out-of-line definitions of constrained templates. This is a recommit of e3b1083. Differential revision: https://reviews.llvm.org/D146178
1 parent 293b483 commit 3a54022

10 files changed

+487
-63
lines changed

clang/include/clang/AST/DeclTemplate.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2309,9 +2309,15 @@ class ClassTemplateDecl : public RedeclarableTemplateDecl {
23092309
return static_cast<Common *>(RedeclarableTemplateDecl::getCommonPtr());
23102310
}
23112311

2312+
void setCommonPtr(Common *C) {
2313+
RedeclarableTemplateDecl::Common = C;
2314+
}
2315+
23122316
public:
2317+
23132318
friend class ASTDeclReader;
23142319
friend class ASTDeclWriter;
2320+
friend class TemplateDeclInstantiator;
23152321

23162322
/// Load any lazily-loaded specializations from the external source.
23172323
void LoadLazySpecializations() const;

clang/include/clang/Sema/Template.h

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -232,9 +232,21 @@ enum class TemplateSubstitutionKind : char {
232232
/// Replaces the current 'innermost' level with the provided argument list.
233233
/// This is useful for type deduction cases where we need to get the entire
234234
/// list from the AST, but then add the deduced innermost list.
235-
void replaceInnermostTemplateArguments(ArgList Args) {
236-
assert(TemplateArgumentLists.size() > 0 && "Replacing in an empty list?");
237-
TemplateArgumentLists[0].Args = Args;
235+
void replaceInnermostTemplateArguments(Decl *AssociatedDecl, ArgList Args) {
236+
assert((!TemplateArgumentLists.empty() || NumRetainedOuterLevels) &&
237+
"Replacing in an empty list?");
238+
239+
if (!TemplateArgumentLists.empty()) {
240+
assert((TemplateArgumentLists[0].AssociatedDeclAndFinal.getPointer() ||
241+
TemplateArgumentLists[0].AssociatedDeclAndFinal.getPointer() ==
242+
AssociatedDecl) &&
243+
"Trying to change incorrect declaration?");
244+
TemplateArgumentLists[0].Args = Args;
245+
} else {
246+
--NumRetainedOuterLevels;
247+
TemplateArgumentLists.push_back(
248+
{{AssociatedDecl, /*Final=*/false}, Args});
249+
}
238250
}
239251

240252
/// Add an outermost level that we are not substituting. We have no

clang/lib/Sema/SemaConcept.cpp

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -721,7 +721,7 @@ CalculateTemplateDepthForConstraints(Sema &S, const NamedDecl *ND,
721721
ND, /*Final=*/false, /*Innermost=*/nullptr, /*RelativeToPrimary=*/true,
722722
/*Pattern=*/nullptr,
723723
/*ForConstraintInstantiation=*/true, SkipForSpecialization);
724-
return MLTAL.getNumSubstitutedLevels();
724+
return MLTAL.getNumLevels();
725725
}
726726

727727
namespace {
@@ -752,27 +752,44 @@ namespace {
752752
};
753753
} // namespace
754754

755+
static const Expr *SubstituteConstraintExpression(Sema &S, const NamedDecl *ND,
756+
const Expr *ConstrExpr) {
757+
MultiLevelTemplateArgumentList MLTAL = S.getTemplateInstantiationArgs(
758+
ND, /*Final=*/false, /*Innermost=*/nullptr,
759+
/*RelativeToPrimary=*/true,
760+
/*Pattern=*/nullptr, /*ForConstraintInstantiation=*/true,
761+
/*SkipForSpecialization*/ false);
762+
if (MLTAL.getNumSubstitutedLevels() == 0)
763+
return ConstrExpr;
764+
765+
Sema::SFINAETrap SFINAE(S, /*AccessCheckingSFINAE=*/false);
766+
std::optional<Sema::CXXThisScopeRAII> ThisScope;
767+
if (auto *RD = dyn_cast<CXXRecordDecl>(ND->getDeclContext()))
768+
ThisScope.emplace(S, const_cast<CXXRecordDecl *>(RD), Qualifiers());
769+
ExprResult SubstConstr =
770+
S.SubstConstraintExpr(const_cast<clang::Expr *>(ConstrExpr), MLTAL);
771+
if (SFINAE.hasErrorOccurred() || !SubstConstr.isUsable())
772+
return nullptr;
773+
return SubstConstr.get();
774+
}
775+
755776
bool Sema::AreConstraintExpressionsEqual(const NamedDecl *Old,
756777
const Expr *OldConstr,
757778
const NamedDecl *New,
758779
const Expr *NewConstr) {
780+
if (OldConstr == NewConstr)
781+
return true;
759782
if (Old && New && Old != New) {
760-
unsigned Depth1 = CalculateTemplateDepthForConstraints(
761-
*this, Old);
762-
unsigned Depth2 = CalculateTemplateDepthForConstraints(
763-
*this, New);
764-
765-
// Adjust the 'shallowest' verison of this to increase the depth to match
766-
// the 'other'.
767-
if (Depth2 > Depth1) {
768-
OldConstr = AdjustConstraintDepth(*this, Depth2 - Depth1)
769-
.TransformExpr(const_cast<Expr *>(OldConstr))
770-
.get();
771-
} else if (Depth1 > Depth2) {
772-
NewConstr = AdjustConstraintDepth(*this, Depth1 - Depth2)
773-
.TransformExpr(const_cast<Expr *>(NewConstr))
774-
.get();
775-
}
783+
if (const Expr *SubstConstr =
784+
SubstituteConstraintExpression(*this, Old, OldConstr))
785+
OldConstr = SubstConstr;
786+
else
787+
return false;
788+
if (const Expr *SubstConstr =
789+
SubstituteConstraintExpression(*this, New, NewConstr))
790+
NewConstr = SubstConstr;
791+
else
792+
return false;
776793
}
777794

778795
llvm::FoldingSetNodeID ID1, ID2;

clang/lib/Sema/SemaOverload.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1296,7 +1296,7 @@ bool Sema::IsOverload(FunctionDecl *New, FunctionDecl *Old,
12961296
// We check the return type and template parameter lists for function
12971297
// templates first; the remaining checks follow.
12981298
bool SameTemplateParameterList = TemplateParameterListsAreEqual(
1299-
NewTemplate->getTemplateParameters(),
1299+
NewTemplate, NewTemplate->getTemplateParameters(), OldTemplate,
13001300
OldTemplate->getTemplateParameters(), false, TPL_TemplateMatch);
13011301
bool SameReturnType = Context.hasSameType(Old->getDeclaredReturnType(),
13021302
New->getDeclaredReturnType());

clang/lib/Sema/SemaTemplateDeduction.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2881,7 +2881,7 @@ CheckDeducedArgumentConstraints(Sema &S, TemplateDeclT *Template,
28812881
// not class-scope explicit specialization, so replace with Deduced Args
28822882
// instead of adding to inner-most.
28832883
if (NeedsReplacement)
2884-
MLTAL.replaceInnermostTemplateArguments(CanonicalDeducedArgs);
2884+
MLTAL.replaceInnermostTemplateArguments(Template, CanonicalDeducedArgs);
28852885

28862886
if (S.CheckConstraintSatisfaction(Template, AssociatedConstraints, MLTAL,
28872887
Info.getLocation(),

clang/lib/Sema/SemaTemplateInstantiate.cpp

Lines changed: 60 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,38 @@ HandleDefaultTempArgIntoTempTempParam(const TemplateTemplateParmDecl *TTP,
131131
return Response::Done();
132132
}
133133

134+
Response HandlePartialClassTemplateSpec(
135+
const ClassTemplatePartialSpecializationDecl *PartialClassTemplSpec,
136+
MultiLevelTemplateArgumentList &Result, bool SkipForSpecialization) {
137+
// We don't want the arguments from the Partial Specialization, since
138+
// anything instantiating here cannot access the arguments from the
139+
// specialized template anyway, so any substitution we would do with these
140+
// partially specialized arguments would 'wrong' and confuse constraint
141+
// instantiation. We only do this in the case of a constraint check, since
142+
// code elsewhere actually uses these and replaces them later with what
143+
// they mean.
144+
// If we know this is the 'top level', we can replace this with an
145+
// OuterRetainedLevel, else we have to generate a set of identity arguments.
146+
147+
// If this is the top-level template entity, we can just add a retained level
148+
// and be done.
149+
if (!PartialClassTemplSpec->getTemplateDepth()) {
150+
if (!SkipForSpecialization)
151+
Result.addOuterRetainedLevel();
152+
return Response::Done();
153+
}
154+
155+
// Else, we can replace this with an 'empty' level, and the checking will just
156+
// alter the 'depth', since this we don't have the 'Index' for this level.
157+
if (!SkipForSpecialization)
158+
Result.addOuterTemplateArguments(
159+
const_cast<ClassTemplatePartialSpecializationDecl *>(
160+
PartialClassTemplSpec),
161+
{}, /*Final=*/false);
162+
163+
return Response::UseNextDecl(PartialClassTemplSpec);
164+
}
165+
134166
// Add template arguments from a class template instantiation.
135167
Response
136168
HandleClassTemplateSpec(const ClassTemplateSpecializationDecl *ClassTemplSpec,
@@ -208,6 +240,21 @@ Response HandleFunction(const FunctionDecl *Function,
208240
return Response::UseNextDecl(Function);
209241
}
210242

243+
Response HandleFunctionTemplateDecl(const FunctionTemplateDecl *FTD,
244+
MultiLevelTemplateArgumentList &Result) {
245+
if (!isa<ClassTemplateSpecializationDecl>(FTD->getDeclContext())) {
246+
NestedNameSpecifier *NNS = FTD->getTemplatedDecl()->getQualifier();
247+
const Type *Ty;
248+
const TemplateSpecializationType *TSTy;
249+
if (NNS && (Ty = NNS->getAsType()) &&
250+
(TSTy = Ty->getAs<TemplateSpecializationType>()))
251+
Result.addOuterTemplateArguments(const_cast<FunctionTemplateDecl *>(FTD),
252+
TSTy->template_arguments(),
253+
/*Final=*/false);
254+
}
255+
return Response::ChangeDecl(FTD->getLexicalDeclContext());
256+
}
257+
211258
Response HandleRecordDecl(const CXXRecordDecl *Rec,
212259
MultiLevelTemplateArgumentList &Result,
213260
ASTContext &Context,
@@ -218,15 +265,10 @@ Response HandleRecordDecl(const CXXRecordDecl *Rec,
218265
"Outer template not instantiated?");
219266
if (ClassTemplate->isMemberSpecialization())
220267
return Response::Done();
221-
if (ForConstraintInstantiation) {
222-
QualType RecordType = Context.getTypeDeclType(Rec);
223-
QualType Injected = cast<InjectedClassNameType>(RecordType)
224-
->getInjectedSpecializationType();
225-
const auto *InjectedType = cast<TemplateSpecializationType>(Injected);
268+
if (ForConstraintInstantiation)
226269
Result.addOuterTemplateArguments(const_cast<CXXRecordDecl *>(Rec),
227-
InjectedType->template_arguments(),
270+
ClassTemplate->getInjectedTemplateArgs(),
228271
/*Final=*/false);
229-
}
230272
}
231273

232274
bool IsFriend = Rec->getFriendObjectKind() ||
@@ -294,18 +336,23 @@ MultiLevelTemplateArgumentList Sema::getTemplateInstantiationArgs(
294336
// Accumulate the set of template argument lists in this structure.
295337
MultiLevelTemplateArgumentList Result;
296338

297-
if (Innermost)
339+
using namespace TemplateInstArgsHelpers;
340+
const Decl *CurDecl = ND;
341+
if (Innermost) {
298342
Result.addOuterTemplateArguments(const_cast<NamedDecl *>(ND),
299343
Innermost->asArray(), Final);
300-
301-
const Decl *CurDecl = ND;
344+
CurDecl = Response::UseNextDecl(ND).NextDecl;
345+
}
302346

303347
while (!CurDecl->isFileContextDecl()) {
304-
using namespace TemplateInstArgsHelpers;
305348
Response R;
306349
if (const auto *VarTemplSpec =
307350
dyn_cast<VarTemplateSpecializationDecl>(CurDecl)) {
308351
R = HandleVarTemplateSpec(VarTemplSpec, Result, SkipForSpecialization);
352+
} else if (const auto *PartialClassTemplSpec =
353+
dyn_cast<ClassTemplatePartialSpecializationDecl>(CurDecl)) {
354+
R = HandlePartialClassTemplateSpec(PartialClassTemplSpec, Result,
355+
SkipForSpecialization);
309356
} else if (const auto *ClassTemplSpec =
310357
dyn_cast<ClassTemplateSpecializationDecl>(CurDecl)) {
311358
R = HandleClassTemplateSpec(ClassTemplSpec, Result,
@@ -318,6 +365,8 @@ MultiLevelTemplateArgumentList Sema::getTemplateInstantiationArgs(
318365
} else if (const auto *CSD =
319366
dyn_cast<ImplicitConceptSpecializationDecl>(CurDecl)) {
320367
R = HandleImplicitConceptSpecializationDecl(CSD, Result);
368+
} else if (const auto *FTD = dyn_cast<FunctionTemplateDecl>(CurDecl)) {
369+
R = HandleFunctionTemplateDecl(FTD, Result);
321370
} else if (!isa<DeclContext>(CurDecl)) {
322371
R = Response::DontClearRelativeToPrimaryNextDecl(CurDecl);
323372
if (CurDecl->getDeclContext()->isTranslationUnit()) {

clang/lib/Sema/SemaTemplateInstantiateDecl.cpp

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1653,33 +1653,12 @@ Decl *TemplateDeclInstantiator::VisitClassTemplateDecl(ClassTemplateDecl *D) {
16531653
<< QualifierLoc.getSourceRange();
16541654
return nullptr;
16551655
}
1656-
1657-
if (PrevClassTemplate) {
1658-
const ClassTemplateDecl *MostRecentPrevCT =
1659-
PrevClassTemplate->getMostRecentDecl();
1660-
TemplateParameterList *PrevParams =
1661-
MostRecentPrevCT->getTemplateParameters();
1662-
1663-
// Make sure the parameter lists match.
1664-
if (!SemaRef.TemplateParameterListsAreEqual(
1665-
D->getTemplatedDecl(), InstParams,
1666-
MostRecentPrevCT->getTemplatedDecl(), PrevParams, true,
1667-
Sema::TPL_TemplateMatch))
1668-
return nullptr;
1669-
1670-
// Do some additional validation, then merge default arguments
1671-
// from the existing declarations.
1672-
if (SemaRef.CheckTemplateParameterList(InstParams, PrevParams,
1673-
Sema::TPC_ClassTemplate))
1674-
return nullptr;
1675-
}
16761656
}
16771657

16781658
CXXRecordDecl *RecordInst = CXXRecordDecl::Create(
16791659
SemaRef.Context, Pattern->getTagKind(), DC, Pattern->getBeginLoc(),
16801660
Pattern->getLocation(), Pattern->getIdentifier(), PrevDecl,
16811661
/*DelayTypeCreation=*/true);
1682-
16831662
if (QualifierLoc)
16841663
RecordInst->setQualifierInfo(QualifierLoc);
16851664

@@ -1689,16 +1668,38 @@ Decl *TemplateDeclInstantiator::VisitClassTemplateDecl(ClassTemplateDecl *D) {
16891668
ClassTemplateDecl *Inst
16901669
= ClassTemplateDecl::Create(SemaRef.Context, DC, D->getLocation(),
16911670
D->getIdentifier(), InstParams, RecordInst);
1692-
assert(!(isFriend && Owner->isDependentContext()));
1693-
Inst->setPreviousDecl(PrevClassTemplate);
1694-
16951671
RecordInst->setDescribedClassTemplate(Inst);
16961672

16971673
if (isFriend) {
1698-
if (PrevClassTemplate)
1674+
assert(!Owner->isDependentContext());
1675+
Inst->setLexicalDeclContext(Owner);
1676+
RecordInst->setLexicalDeclContext(Owner);
1677+
1678+
if (PrevClassTemplate) {
1679+
Inst->setCommonPtr(PrevClassTemplate->getCommonPtr());
1680+
RecordInst->setTypeForDecl(
1681+
PrevClassTemplate->getTemplatedDecl()->getTypeForDecl());
1682+
const ClassTemplateDecl *MostRecentPrevCT =
1683+
PrevClassTemplate->getMostRecentDecl();
1684+
TemplateParameterList *PrevParams =
1685+
MostRecentPrevCT->getTemplateParameters();
1686+
1687+
// Make sure the parameter lists match.
1688+
if (!SemaRef.TemplateParameterListsAreEqual(
1689+
RecordInst, InstParams, MostRecentPrevCT->getTemplatedDecl(),
1690+
PrevParams, true, Sema::TPL_TemplateMatch))
1691+
return nullptr;
1692+
1693+
// Do some additional validation, then merge default arguments
1694+
// from the existing declarations.
1695+
if (SemaRef.CheckTemplateParameterList(InstParams, PrevParams,
1696+
Sema::TPC_ClassTemplate))
1697+
return nullptr;
1698+
16991699
Inst->setAccess(PrevClassTemplate->getAccess());
1700-
else
1700+
} else {
17011701
Inst->setAccess(D->getAccess());
1702+
}
17021703

17031704
Inst->setObjectOfFriendDecl();
17041705
// TODO: do we want to track the instantiation progeny of this
@@ -1709,15 +1710,15 @@ Decl *TemplateDeclInstantiator::VisitClassTemplateDecl(ClassTemplateDecl *D) {
17091710
Inst->setInstantiatedFromMemberTemplate(D);
17101711
}
17111712

1713+
Inst->setPreviousDecl(PrevClassTemplate);
1714+
17121715
// Trigger creation of the type for the instantiation.
1713-
SemaRef.Context.getInjectedClassNameType(RecordInst,
1714-
Inst->getInjectedClassNameSpecialization());
1716+
SemaRef.Context.getInjectedClassNameType(
1717+
RecordInst, Inst->getInjectedClassNameSpecialization());
17151718

17161719
// Finish handling of friends.
17171720
if (isFriend) {
17181721
DC->makeDeclVisibleInContext(Inst);
1719-
Inst->setLexicalDeclContext(Owner);
1720-
RecordInst->setLexicalDeclContext(Owner);
17211722
return Inst;
17221723
}
17231724

clang/test/SemaTemplate/concepts-friends.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,3 +441,27 @@ namespace NTTP {
441441
templ_func<1>(u2);
442442
}
443443
}
444+
445+
446+
namespace FriendOfFriend {
447+
448+
template <typename>
449+
concept Concept = true;
450+
451+
template <Concept> class FriendOfBar;
452+
453+
template <Concept> class Bar {
454+
template <Concept> friend class FriendOfBar;
455+
};
456+
457+
Bar<void> BarInstance;
458+
459+
namespace internal {
460+
void FriendOfFoo(FriendOfBar<void>);
461+
}
462+
463+
template <Concept> class Foo {
464+
friend void internal::FriendOfFoo(FriendOfBar<void>);
465+
};
466+
467+
} // namespace FriendOfFriend

0 commit comments

Comments
 (0)