Skip to content

[Attributor] Look through indirect calls #65197

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 13 additions & 11 deletions llvm/include/llvm/Transforms/IPO/Attributor.h
Original file line number Diff line number Diff line change
Expand Up @@ -2125,6 +2125,15 @@ struct Attributor {
const AAIsDead *FnLivenessAA,
DepClassTy DepClass = DepClassTy::OPTIONAL);

/// Check \p Pred on all potential Callees of \p CB.
///
/// This method will evaluate \p Pred with all potential callees of \p CB as
/// input and return true if \p Pred does. If some callees might be unknown
/// this function will return false.
bool checkForAllCallees(
function_ref<bool(ArrayRef<const Function *> Callees)> Pred,
const AbstractAttribute &QueryingAA, const CallBase &CB);

/// Check \p Pred on all (transitive) uses of \p V.
///
/// This method will evaluate \p Pred on all (transitive) uses of the
Expand Down Expand Up @@ -3295,7 +3304,7 @@ struct AbstractAttribute : public IRPosition, public AADepGraphNode {

/// Return true if this AA requires a "callee" (or an associted function) for
/// a call site positon. Default is optimistic to minimize AAs.
static bool requiresCalleeForCallBase() { return true; }
static bool requiresCalleeForCallBase() { return false; }

/// Return true if this AA requires non-asm "callee" for a call site positon.
static bool requiresNonAsmForCallBase() { return true; }
Expand Down Expand Up @@ -3852,9 +3861,6 @@ struct AANoAlias
Attribute::AttrKind ImpliedAttributeKind,
bool IgnoreSubsumingPositions = false);

/// See AbstractAttribute::requiresCalleeForCallBase
static bool requiresCalleeForCallBase() { return false; }

/// See AbstractAttribute::requiresCallersForArgOrFunction
static bool requiresCallersForArgOrFunction() { return true; }

Expand Down Expand Up @@ -4699,6 +4705,9 @@ struct AAMemoryLocation

AAMemoryLocation(const IRPosition &IRP, Attributor &A) : IRAttribute(IRP) {}

/// See AbstractAttribute::requiresCalleeForCallBase.
static bool requiresCalleeForCallBase() { return true; }

/// See AbstractAttribute::hasTrivialInitializer.
static bool hasTrivialInitializer() { return false; }

Expand Down Expand Up @@ -5481,10 +5490,6 @@ struct AACallEdges : public StateWrapper<BooleanState, AbstractAttribute>,
AACallEdges(const IRPosition &IRP, Attributor &A)
: Base(IRP), AACallGraphNode(A) {}

/// The callee value is tracked beyond a simple stripPointerCasts, so we allow
/// unknown callees.
static bool requiresCalleeForCallBase() { return false; }

/// See AbstractAttribute::requiresNonAsmForCallBase.
static bool requiresNonAsmForCallBase() { return false; }

Expand Down Expand Up @@ -6310,9 +6315,6 @@ struct AAIndirectCallInfo
AAIndirectCallInfo(const IRPosition &IRP, Attributor &A)
: StateWrapper<BooleanState, AbstractAttribute>(IRP) {}

/// The point is to derive callees, after all.
static bool requiresCalleeForCallBase() { return false; }

/// See AbstractAttribute::isValidIRPositionForInit
static bool isValidIRPositionForInit(Attributor &A, const IRPosition &IRP) {
if (IRP.getPositionKind() != IRPosition::IRP_CALL_SITE)
Expand Down
15 changes: 15 additions & 0 deletions llvm/lib/Transforms/IPO/Attributor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1733,6 +1733,21 @@ bool Attributor::isAssumedDead(const BasicBlock &BB,
return false;
}

bool Attributor::checkForAllCallees(
function_ref<bool(ArrayRef<const Function *>)> Pred,
const AbstractAttribute &QueryingAA, const CallBase &CB) {
if (const Function *Callee = dyn_cast<Function>(CB.getCalledOperand()))
return Pred(Callee);

const auto *CallEdgesAA = getAAFor<AACallEdges>(
QueryingAA, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
if (!CallEdgesAA || CallEdgesAA->hasUnknownCallee())
return false;

const auto &Callees = CallEdgesAA->getOptimisticEdges();
return Pred(Callees.getArrayRef());
}

bool Attributor::checkForAllUses(
function_ref<bool(const Use &, bool &)> Pred,
const AbstractAttribute &QueryingAA, const Value &V,
Expand Down
52 changes: 29 additions & 23 deletions llvm/lib/Transforms/IPO/AttributorAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -604,36 +604,42 @@ struct AACalleeToCallSite : public BaseType {
"returned positions!");
auto &S = this->getState();

const Function *AssociatedFunction =
this->getIRPosition().getAssociatedFunction();
if (!AssociatedFunction)
return S.indicatePessimisticFixpoint();

CallBase &CB = cast<CallBase>(this->getAnchorValue());
if (IntroduceCallBaseContext)
LLVM_DEBUG(dbgs() << "[Attributor] Introducing call base context:" << CB
<< "\n");

IRPosition FnPos =
IRPKind == llvm::IRPosition::IRP_CALL_SITE_RETURNED
? IRPosition::returned(*AssociatedFunction,
IntroduceCallBaseContext ? &CB : nullptr)
: IRPosition::function(*AssociatedFunction,
IntroduceCallBaseContext ? &CB : nullptr);

// If possible, use the hasAssumedIRAttr interface.
if (Attribute::isEnumAttrKind(IRAttributeKind)) {
bool IsKnown;
if (!AA::hasAssumedIRAttr<IRAttributeKind>(A, this, FnPos,
DepClassTy::REQUIRED, IsKnown))
return S.indicatePessimisticFixpoint();
return ChangeStatus::UNCHANGED;
}
ChangeStatus Changed = ChangeStatus::UNCHANGED;
auto CalleePred = [&](ArrayRef<const Function *> Callees) {
for (const Function *Callee : Callees) {
IRPosition FnPos =
IRPKind == llvm::IRPosition::IRP_CALL_SITE_RETURNED
? IRPosition::returned(*Callee,
IntroduceCallBaseContext ? &CB : nullptr)
: IRPosition::function(
*Callee, IntroduceCallBaseContext ? &CB : nullptr);
// If possible, use the hasAssumedIRAttr interface.
if (Attribute::isEnumAttrKind(IRAttributeKind)) {
bool IsKnown;
if (!AA::hasAssumedIRAttr<IRAttributeKind>(
A, this, FnPos, DepClassTy::REQUIRED, IsKnown))
return false;
continue;
}

const AAType *AA = A.getAAFor<AAType>(*this, FnPos, DepClassTy::REQUIRED);
if (!AA)
const AAType *AA =
A.getAAFor<AAType>(*this, FnPos, DepClassTy::REQUIRED);
if (!AA)
return false;
Changed |= clampStateAndIndicateChange(S, AA->getState());
if (S.isAtFixpoint())
return S.isValidState();
}
return true;
};
if (!A.checkForAllCallees(CalleePred, *this, CB))
return S.indicatePessimisticFixpoint();
return clampStateAndIndicateChange(S, AA->getState());
return Changed;
}
};

Expand Down
6 changes: 3 additions & 3 deletions llvm/test/Transforms/Attributor/liveness.ll
Original file line number Diff line number Diff line change
Expand Up @@ -2359,12 +2359,12 @@ define void @call_via_pointer_with_dead_args(ptr %a, ptr %b, ptr %fp) {
define internal void @call_via_pointer_with_dead_args_internal_a(ptr %a, ptr %b, ptr %fp) {
; TUNIT-LABEL: define {{[^@]+}}@call_via_pointer_with_dead_args_internal_a
; TUNIT-SAME: (ptr [[A:%.*]], ptr noundef nonnull align 128 dereferenceable(4) [[B:%.*]]) {
; TUNIT-NEXT: call void @called_via_pointer(ptr [[A]], ptr [[B]], ptr [[A]], i64 -1, ptr null)
; TUNIT-NEXT: call void @called_via_pointer(ptr [[A]], ptr nonnull align 128 dereferenceable(4) [[B]], ptr [[A]], i64 -1, ptr null)
; TUNIT-NEXT: ret void
;
; CGSCC-LABEL: define {{[^@]+}}@call_via_pointer_with_dead_args_internal_a
; CGSCC-SAME: (ptr [[A:%.*]], ptr noundef nonnull align 128 dereferenceable(4) [[B:%.*]]) {
; CGSCC-NEXT: call void @called_via_pointer(ptr [[A]], ptr nocapture nofree noundef nonnull [[B]], ptr nocapture nofree [[A]], i64 noundef -1, ptr nofree noundef null)
; CGSCC-NEXT: call void @called_via_pointer(ptr [[A]], ptr nocapture nofree noundef nonnull align 128 dereferenceable(4) [[B]], ptr nocapture nofree [[A]], i64 noundef -1, ptr nofree noundef null)
; CGSCC-NEXT: ret void
;
call void %fp(ptr %a, ptr %b, ptr %a, i64 -1, ptr null)
Expand All @@ -2373,7 +2373,7 @@ define internal void @call_via_pointer_with_dead_args_internal_a(ptr %a, ptr %b,
define internal void @call_via_pointer_with_dead_args_internal_b(ptr %a, ptr %b, ptr %fp) {
; TUNIT-LABEL: define {{[^@]+}}@call_via_pointer_with_dead_args_internal_b
; TUNIT-SAME: (ptr [[A:%.*]], ptr noundef nonnull align 128 dereferenceable(4) [[B:%.*]]) {
; TUNIT-NEXT: call void @called_via_pointer_internal_2(ptr [[A]], ptr [[B]], ptr [[A]], i64 -1, ptr null)
; TUNIT-NEXT: call void @called_via_pointer_internal_2(ptr [[A]], ptr nonnull align 128 dereferenceable(4) [[B]], ptr [[A]], i64 -1, ptr null)
; TUNIT-NEXT: ret void
;
; CGSCC-LABEL: define {{[^@]+}}@call_via_pointer_with_dead_args_internal_b
Expand Down
26 changes: 7 additions & 19 deletions llvm/test/Transforms/Attributor/nounwind.ll
Original file line number Diff line number Diff line change
Expand Up @@ -151,27 +151,14 @@ define i32 @catch_thing_user() {
}

define void @two_potential_callees_pos1(i1 %c) {
; TUNIT: Function Attrs: norecurse
; TUNIT: Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(none)
; TUNIT-LABEL: define {{[^@]+}}@two_potential_callees_pos1
; TUNIT-SAME: (i1 [[C:%.*]]) #[[ATTR3:[0-9]+]] {
; TUNIT-NEXT: [[FP:%.*]] = select i1 [[C]], ptr @foo1, ptr @scc1_foo
; TUNIT-NEXT: [[TMP1:%.*]] = icmp eq ptr [[FP]], @scc1_foo
; TUNIT-NEXT: br i1 [[TMP1]], label [[TMP2:%.*]], label [[TMP3:%.*]]
; TUNIT: 2:
; TUNIT-NEXT: call void @scc1_foo()
; TUNIT-NEXT: br label [[TMP6:%.*]]
; TUNIT: 3:
; TUNIT-NEXT: br i1 true, label [[TMP4:%.*]], label [[TMP5:%.*]]
; TUNIT: 4:
; TUNIT-NEXT: call void @foo1()
; TUNIT-NEXT: br label [[TMP6]]
; TUNIT: 5:
; TUNIT-NEXT: unreachable
; TUNIT: 6:
; TUNIT-SAME: (i1 [[C:%.*]]) #[[ATTR0]] {
; TUNIT-NEXT: ret void
;
; CGSCC: Function Attrs: mustprogress nofree nosync nounwind willreturn
; CGSCC-LABEL: define {{[^@]+}}@two_potential_callees_pos1
; CGSCC-SAME: (i1 [[C:%.*]]) {
; CGSCC-SAME: (i1 [[C:%.*]]) #[[ATTR2:[0-9]+]] {
; CGSCC-NEXT: [[FP:%.*]] = select i1 [[C]], ptr @foo1, ptr @scc1_foo
; CGSCC-NEXT: [[TMP1:%.*]] = icmp eq ptr [[FP]], @scc1_foo
; CGSCC-NEXT: br i1 [[TMP1]], label [[TMP2:%.*]], label [[TMP3:%.*]]
Expand All @@ -193,8 +180,9 @@ define void @two_potential_callees_pos1(i1 %c) {
ret void
}
define void @two_potential_callees_pos2(i1 %c) {
; CHECK: Function Attrs: nounwind
; CHECK-LABEL: define {{[^@]+}}@two_potential_callees_pos2
; CHECK-SAME: (i1 [[C:%.*]]) {
; CHECK-SAME: (i1 [[C:%.*]]) #[[ATTR1]] {
; CHECK-NEXT: [[FP:%.*]] = select i1 [[C]], ptr @foo2, ptr @scc1_foo
; CHECK-NEXT: [[TMP1:%.*]] = icmp eq ptr [[FP]], @scc1_foo
; CHECK-NEXT: br i1 [[TMP1]], label [[TMP2:%.*]], label [[TMP3:%.*]]
Expand Down Expand Up @@ -248,8 +236,8 @@ declare void @__cxa_end_catch()
; TUNIT: attributes #[[ATTR0]] = { mustprogress nofree norecurse nosync nounwind willreturn memory(none) }
; TUNIT: attributes #[[ATTR1]] = { nounwind }
; TUNIT: attributes #[[ATTR2]] = { mustprogress nofree nosync nounwind willreturn memory(none) }
; TUNIT: attributes #[[ATTR3]] = { norecurse }
;.
; CGSCC: attributes #[[ATTR0]] = { mustprogress nofree norecurse nosync nounwind willreturn memory(none) }
; CGSCC: attributes #[[ATTR1]] = { nounwind }
; CGSCC: attributes #[[ATTR2]] = { mustprogress nofree nosync nounwind willreturn }
;.
Loading