Skip to content

Commit cf65343

Browse files
wsmosesvchuravy
andcommitted
AMD + NVIDIA Specific Fixes (#180)
* handle amdgcn intrinsics * constant addrspace is constant * make sure we don't use constant AS for shadow * Cleanup julia/amd Co-authored-by: Valentin Churavy <[email protected]>
1 parent 04ef6a7 commit cf65343

File tree

10 files changed

+204
-139
lines changed

10 files changed

+204
-139
lines changed

enzyme/Enzyme/ActivityAnalysis.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,7 @@ bool ActivityAnalyzer::isConstantInstruction(TypeResults &TR, Instruction *I) {
332332
case Intrinsic::nvvm_membar_cta:
333333
case Intrinsic::nvvm_membar_gl:
334334
case Intrinsic::nvvm_membar_sys:
335+
case Intrinsic::amdgcn_s_barrier:
335336
case Intrinsic::assume:
336337
case Intrinsic::stacksave:
337338
case Intrinsic::stackrestore:
@@ -548,6 +549,7 @@ bool ActivityAnalyzer::isConstantValue(TypeResults &TR, Value *Val) {
548549
case Intrinsic::nvvm_membar_cta:
549550
case Intrinsic::nvvm_membar_gl:
550551
case Intrinsic::nvvm_membar_sys:
552+
case Intrinsic::amdgcn_s_barrier:
551553
case Intrinsic::assume:
552554
case Intrinsic::stacksave:
553555
case Intrinsic::stackrestore:
@@ -1060,6 +1062,7 @@ bool ActivityAnalyzer::isConstantValue(TypeResults &TR, Value *Val) {
10601062
case Intrinsic::nvvm_membar_cta:
10611063
case Intrinsic::nvvm_membar_gl:
10621064
case Intrinsic::nvvm_membar_sys:
1065+
case Intrinsic::amdgcn_s_barrier:
10631066
case Intrinsic::assume:
10641067
case Intrinsic::stacksave:
10651068
case Intrinsic::stackrestore:
@@ -1532,6 +1535,7 @@ bool ActivityAnalyzer::isInstructionInactiveFromOrigin(TypeResults &TR,
15321535
case Intrinsic::nvvm_membar_cta:
15331536
case Intrinsic::nvvm_membar_gl:
15341537
case Intrinsic::nvvm_membar_sys:
1538+
case Intrinsic::amdgcn_s_barrier:
15351539
case Intrinsic::assume:
15361540
case Intrinsic::stacksave:
15371541
case Intrinsic::stackrestore:

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1763,6 +1763,7 @@ class AdjointGenerator
17631763
case Intrinsic::nvvm_membar_cta:
17641764
case Intrinsic::nvvm_membar_gl:
17651765
case Intrinsic::nvvm_membar_sys:
1766+
case Intrinsic::amdgcn_s_barrier:
17661767

17671768
case Intrinsic::prefetch:
17681769
case Intrinsic::dbg_declare:
@@ -1849,6 +1850,7 @@ class AdjointGenerator
18491850
}
18501851

18511852
case Intrinsic::nvvm_barrier0:
1853+
case Intrinsic::amdgcn_s_barrier:
18521854
case Intrinsic::nvvm_membar_cta:
18531855
case Intrinsic::nvvm_membar_gl:
18541856
case Intrinsic::nvvm_membar_sys: {

enzyme/Enzyme/Enzyme.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -139,13 +139,13 @@ class Enzyme : public ModulePass {
139139
unsigned truei = 0;
140140
IRBuilder<> Builder(CI);
141141

142-
bool AtomicAdd =
142+
auto Arch =
143143
llvm::Triple(
144144
CI->getParent()->getParent()->getParent()->getTargetTriple())
145-
.getArch() == Triple::nvptx ||
146-
llvm::Triple(
147-
CI->getParent()->getParent()->getParent()->getTargetTriple())
148-
.getArch() == Triple::nvptx64;
145+
.getArch();
146+
147+
bool AtomicAdd = Arch == Triple::nvptx || Arch == Triple::nvptx64 ||
148+
Arch == Triple::amdgcn;
149149

150150
for (unsigned i = 1; i < CI->getNumArgOperands(); ++i) {
151151
Value *res = CI->getArgOperand(i);

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 47 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@
5353
#include "llvm/Analysis/BasicAliasAnalysis.h"
5454
#include "llvm/Analysis/GlobalsModRef.h"
5555

56+
#include "llvm/Support/AMDGPUMetadata.h"
57+
5658
#include "FunctionUtils.h"
5759
#include "GradientUtils.h"
5860
#include "LibraryFuncs.h"
@@ -219,6 +221,13 @@ struct CacheAnalysis {
219221
bool is_load_uncacheable(LoadInst &li) {
220222
assert(li.getParent()->getParent() == oldFunc);
221223

224+
auto Arch = llvm::Triple(oldFunc->getParent()->getTargetTriple()).getArch();
225+
if (Arch == Triple::amdgcn &&
226+
cast<PointerType>(li.getPointerOperand()->getType())
227+
->getAddressSpace() == 4) {
228+
return false;
229+
}
230+
222231
// Find the underlying object for the pointer operand of the load
223232
// instruction.
224233
auto obj =
@@ -244,7 +253,8 @@ struct CacheAnalysis {
244253
return false;
245254
}
246255
if (auto II = dyn_cast<IntrinsicInst>(inst2)) {
247-
if (II->getIntrinsicID() == Intrinsic::nvvm_barrier0) {
256+
if (II->getIntrinsicID() == Intrinsic::nvvm_barrier0 ||
257+
II->getIntrinsicID() == Intrinsic::amdgcn_s_barrier) {
248258
allUnsyncdPredecessorsOf(
249259
II,
250260
[&](Instruction *mid) {
@@ -1843,10 +1853,8 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
18431853
user->setCalledFunction(NewF);
18441854
}
18451855
}
1846-
if (llvm::Triple(NewF->getParent()->getTargetTriple()).getArch() ==
1847-
Triple::nvptx ||
1848-
llvm::Triple(NewF->getParent()->getTargetTriple()).getArch() ==
1849-
Triple::nvptx64)
1856+
auto Arch = llvm::Triple(NewF->getParent()->getTargetTriple()).getArch();
1857+
if (Arch == Triple::nvptx || Arch == Triple::nvptx64)
18501858
PPC.ReplaceReallocs(NewF, /*mem2reg*/ true);
18511859
if (PostOpt)
18521860
PPC.optimizeIntermediate(NewF);
@@ -2688,14 +2696,22 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
26882696

26892697
BasicBlock *entry = &gutils->newFunc->getEntryBlock();
26902698

2699+
auto Arch =
2700+
llvm::Triple(gutils->newFunc->getParent()->getTargetTriple()).getArch();
2701+
int SharedAddrSpace = Arch == Triple::amdgcn
2702+
? (int)AMDGPU::HSAMD::AddressSpaceQualifier::Local
2703+
: 3;
2704+
26912705
if (topLevel) {
26922706
BasicBlock *sharedBlock = nullptr;
26932707
for (auto &g : gutils->newFunc->getParent()->globals()) {
26942708
if (hasMetadata(&g, "enzyme_internalshadowglobal")) {
26952709
IRBuilder<> entryBuilder(gutils->inversionAllocs,
26962710
gutils->inversionAllocs->begin());
26972711

2698-
if (g.getType()->getAddressSpace() == 3) {
2712+
if ((Arch == Triple::nvptx || Arch == Triple::nvptx64 ||
2713+
Arch == Triple::amdgcn) &&
2714+
g.getType()->getAddressSpace() == SharedAddrSpace) {
26992715
if (sharedBlock == nullptr)
27002716
sharedBlock = BasicBlock::Create(entry->getContext(), "shblock",
27012717
gutils->newFunc);
@@ -2718,24 +2734,34 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
27182734
entry->getTerminator()->eraseFromParent();
27192735
IRBuilder<> ebuilder(entry);
27202736

2721-
Value *tx = ebuilder.CreateCall(Intrinsic::getDeclaration(
2722-
gutils->newFunc->getParent(), Intrinsic::nvvm_read_ptx_sreg_tid_x));
2723-
tx = ebuilder.CreateICmpEQ(tx, ConstantInt::get(tx->getType(), 0));
2724-
Value *ty = ebuilder.CreateCall(Intrinsic::getDeclaration(
2725-
gutils->newFunc->getParent(), Intrinsic::nvvm_read_ptx_sreg_tid_y));
2726-
ty = ebuilder.CreateICmpEQ(ty, ConstantInt::get(ty->getType(), 0));
2727-
Value *tz = ebuilder.CreateCall(Intrinsic::getDeclaration(
2728-
gutils->newFunc->getParent(), Intrinsic::nvvm_read_ptx_sreg_tid_z));
2729-
tz = ebuilder.CreateICmpEQ(tz, ConstantInt::get(tz->getType(), 0));
2730-
2731-
ebuilder.CreateCondBr(ebuilder.CreateAnd(ebuilder.CreateAnd(tx, ty), tz),
2732-
sharedBlock, OldEntryInsts);
2737+
Value *tx, *ty, *tz;
2738+
if (Arch == Triple::nvptx || Arch == Triple::nvptx64) {
2739+
tx = ebuilder.CreateCall(Intrinsic::getDeclaration(
2740+
gutils->newFunc->getParent(), Intrinsic::nvvm_read_ptx_sreg_tid_x));
2741+
ty = ebuilder.CreateCall(Intrinsic::getDeclaration(
2742+
gutils->newFunc->getParent(), Intrinsic::nvvm_read_ptx_sreg_tid_y));
2743+
tz = ebuilder.CreateCall(Intrinsic::getDeclaration(
2744+
gutils->newFunc->getParent(), Intrinsic::nvvm_read_ptx_sreg_tid_z));
2745+
} else if (Arch == Triple::amdgcn) {
2746+
tx = ebuilder.CreateCall(Intrinsic::getDeclaration(
2747+
gutils->newFunc->getParent(), Intrinsic::amdgcn_workitem_id_x));
2748+
ty = ebuilder.CreateCall(Intrinsic::getDeclaration(
2749+
gutils->newFunc->getParent(), Intrinsic::amdgcn_workitem_id_y));
2750+
tz = ebuilder.CreateCall(Intrinsic::getDeclaration(
2751+
gutils->newFunc->getParent(), Intrinsic::amdgcn_workitem_id_z));
2752+
}
2753+
Value *AndVal = ebuilder.CreateAnd(ebuilder.CreateAnd(tx, ty), tz);
2754+
2755+
ebuilder.CreateCondBr(
2756+
ebuilder.CreateICmpEQ(AndVal, ConstantInt::get(AndVal->getType(), 0)),
2757+
sharedBlock, OldEntryInsts);
27332758

27342759
IRBuilder<> instbuilder(OldEntryInsts, OldEntryInsts->begin());
27352760

2761+
auto BarrierInst = Arch == Triple::amdgcn ? Intrinsic::amdgcn_s_barrier
2762+
: Intrinsic::nvvm_barrier0;
27362763
cast<CallInst>(instbuilder.CreateCall(
2737-
Intrinsic::getDeclaration(gutils->newFunc->getParent(),
2738-
Intrinsic::nvvm_barrier0),
2764+
Intrinsic::getDeclaration(gutils->newFunc->getParent(), BarrierInst),
27392765
{}));
27402766
OldEntryInsts->moveAfter(entry);
27412767
sharedBlock->moveAfter(entry);
@@ -2816,10 +2842,7 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
28162842
PreservedAnalyses PA;
28172843
PPC.FAM.invalidate(*gutils->newFunc, PA);
28182844
}
2819-
if (llvm::Triple(gutils->newFunc->getParent()->getTargetTriple()).getArch() ==
2820-
Triple::nvptx ||
2821-
llvm::Triple(gutils->newFunc->getParent()->getTargetTriple()).getArch() ==
2822-
Triple::nvptx64)
2845+
if (Arch == Triple::nvptx || Arch == Triple::nvptx64)
28232846
PPC.ReplaceReallocs(gutils->newFunc, /*mem2reg*/ true);
28242847
if (PostOpt)
28252848
PPC.optimizeIntermediate(gutils->newFunc);

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141

4242
#include "llvm/Analysis/ValueTracking.h"
4343
#include "llvm/IR/InstrTypes.h"
44+
#include "llvm/Support/AMDGPUMetadata.h"
4445
#include "llvm/Transforms/Utils/SimplifyIndVar.h"
4546

4647
std::map<std::string, std::function<llvm::Value *(IRBuilder<> &, CallInst *,
@@ -2045,11 +2046,16 @@ Value *GradientUtils::invertPointerM(Value *oval, IRBuilder<> &BuilderM) {
20452046
}
20462047
}
20472048

2048-
if ((llvm::Triple(newFunc->getParent()->getTargetTriple()).getArch() ==
2049-
Triple::nvptx ||
2050-
llvm::Triple(newFunc->getParent()->getTargetTriple()).getArch() ==
2051-
Triple::nvptx64) &&
2052-
cast<PointerType>(arg->getType())->getAddressSpace() == 3) {
2049+
auto Arch =
2050+
llvm::Triple(newFunc->getParent()->getTargetTriple()).getArch();
2051+
int SharedAddrSpace =
2052+
Arch == Triple::amdgcn
2053+
? (int)AMDGPU::HSAMD::AddressSpaceQualifier::Local
2054+
: 3;
2055+
int AddrSpace = cast<PointerType>(arg->getType())->getAddressSpace();
2056+
if ((Arch == Triple::nvptx || Arch == Triple::nvptx64 ||
2057+
Arch == Triple::amdgcn) &&
2058+
AddrSpace == SharedAddrSpace) {
20532059
llvm::errs() << "warning found shared memory\n";
20542060
//#if LLVM_VERSION_MAJOR >= 11
20552061
Type *type = cast<PointerType>(arg->getType())->getElementType();
@@ -2537,8 +2543,14 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
25372543
}
25382544
}
25392545
if (auto LI = dyn_cast<LoadInst>(inst)) {
2546+
auto Arch =
2547+
llvm::Triple(newFunc->getParent()->getTargetTriple()).getArch();
2548+
unsigned int SharedAddrSpace =
2549+
Arch == Triple::amdgcn
2550+
? (int)AMDGPU::HSAMD::AddressSpaceQualifier::Local
2551+
: 3;
25402552
if (cast<PointerType>(LI->getPointerOperand()->getType())
2541-
->getAddressSpace() == 3) {
2553+
->getAddressSpace() == SharedAddrSpace) {
25422554
reduceRegister |= tryLegalRecomputeCheck &&
25432555
legalRecompute(LI, incoming_available, &BuilderM) &&
25442556
shouldRecompute(LI, incoming_available, &BuilderM);
@@ -2850,8 +2862,16 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
28502862
}
28512863

28522864
auto scev1 = OrigSE.getSCEV(origInst->getPointerOperand());
2865+
2866+
auto Arch =
2867+
llvm::Triple(newFunc->getParent()->getTargetTriple()).getArch();
2868+
unsigned int SharedAddrSpace =
2869+
Arch == Triple::amdgcn
2870+
? (int)AMDGPU::HSAMD::AddressSpaceQualifier::Local
2871+
: 3;
28532872
if (EnzymeSharedForward && scev1 != OrigSE.getCouldNotCompute() &&
2854-
cast<PointerType>(orig_liobj->getType())->getAddressSpace() == 3) {
2873+
cast<PointerType>(orig_liobj->getType())->getAddressSpace() ==
2874+
SharedAddrSpace) {
28552875
Value *resultValue = nullptr;
28562876
ValueToValueMapTy newavail;
28572877
for (const auto &pair : available) {
@@ -2882,7 +2902,8 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
28822902
return false;
28832903

28842904
if (auto II = dyn_cast<IntrinsicInst>(potentialAlias)) {
2885-
if (II->getIntrinsicID() == Intrinsic::nvvm_barrier0) {
2905+
if (II->getIntrinsicID() == Intrinsic::nvvm_barrier0 ||
2906+
II->getIntrinsicID() == Intrinsic::amdgcn_s_barrier) {
28862907
interveningSync =
28872908
DT.dominates(SI, II) && DT.dominates(II, origInst);
28882909
allUnsyncdPredecessorsOf(
@@ -2995,20 +3016,22 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
29953016
}
29963017
if (ls != ss) {
29973018
if (auto II = dyn_cast<IntrinsicInst>(svals[i])) {
2998-
if (II->getIntrinsicID() ==
2999-
Intrinsic::nvvm_read_ptx_sreg_tid_x ||
3000-
II->getIntrinsicID() ==
3001-
Intrinsic::nvvm_read_ptx_sreg_tid_y ||
3002-
II->getIntrinsicID() ==
3003-
Intrinsic::nvvm_read_ptx_sreg_tid_z) {
3019+
switch (II->getIntrinsicID()) {
3020+
case Intrinsic::nvvm_read_ptx_sreg_tid_x:
3021+
case Intrinsic::nvvm_read_ptx_sreg_tid_y:
3022+
case Intrinsic::nvvm_read_ptx_sreg_tid_z:
3023+
case Intrinsic::amdgcn_workitem_id_x:
3024+
case Intrinsic::amdgcn_workitem_id_y:
3025+
case Intrinsic::amdgcn_workitem_id_z:
30043026
ThreadLookup[getNewFromOriginal(II)] =
30053027
BuilderM.CreateZExtOrTrunc(
30063028
lookupM(getNewFromOriginal(lvals[i]),
30073029
BuilderM, available),
30083030
II->getType());
3009-
} else {
3010-
;
3031+
break;
3032+
default:
30113033
legal = false;
3034+
break;
30123035
}
30133036
} else {
30143037
legal = false;

enzyme/Enzyme/GradientUtils.h

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1553,18 +1553,26 @@ class DiffeGradientUtils : public GradientUtils {
15531553

15541554
// atomics
15551555
bool Atomic = AtomicAdd;
1556+
auto Arch = llvm::Triple(newFunc->getParent()->getTargetTriple()).getArch();
15561557

15571558
// No need to do atomic on local memory for CUDA since it can't be raced
15581559
// upon
15591560
if (isa<AllocaInst>(TmpOrig) &&
1560-
(llvm::Triple(newFunc->getParent()->getTargetTriple()).getArch() ==
1561-
Triple::nvptx ||
1562-
llvm::Triple(newFunc->getParent()->getTargetTriple()).getArch() ==
1563-
Triple::nvptx64)) {
1561+
(Arch == Triple::nvptx || Arch == Triple::nvptx64 ||
1562+
Arch == Triple::amdgcn)) {
15641563
Atomic = false;
15651564
}
15661565

15671566
if (Atomic) {
1567+
// For amdgcn constant AS is 4 and if the primal is in it we need to cast
1568+
// the derivative value to AS 1
1569+
auto AS = cast<PointerType>(ptr->getType())->getAddressSpace();
1570+
if (Arch == Triple::amdgcn && AS == 4) {
1571+
ptr = BuilderM.CreateAddrSpaceCast(
1572+
ptr, PointerType::get(
1573+
cast<PointerType>(ptr->getType())->getElementType(), 1));
1574+
}
1575+
15681576
/*
15691577
while (auto ASC = dyn_cast<AddrSpaceCastInst>(ptr)) {
15701578
ptr = ASC->getOperand(0);

enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2172,6 +2172,9 @@ void TypeAnalyzer::visitIntrinsicInst(llvm::IntrinsicInst &I) {
21722172
case Intrinsic::nvvm_read_ptx_sreg_nctaid_y:
21732173
case Intrinsic::nvvm_read_ptx_sreg_nctaid_z:
21742174
case Intrinsic::nvvm_read_ptx_sreg_warpsize:
2175+
case Intrinsic::amdgcn_workitem_id_x:
2176+
case Intrinsic::amdgcn_workitem_id_y:
2177+
case Intrinsic::amdgcn_workitem_id_z:
21752178
// No direction check as always valid
21762179
updateAnalysis(&I, TypeTree(BaseType::Integer).Only(-1), &I);
21772180
return;
@@ -3598,6 +3601,9 @@ std::set<int64_t> FnTypeInfo::knownIntegralValues(
35983601
case Intrinsic::nvvm_read_ptx_sreg_ctaid_x:
35993602
case Intrinsic::nvvm_read_ptx_sreg_ctaid_y:
36003603
case Intrinsic::nvvm_read_ptx_sreg_ctaid_z:
3604+
case Intrinsic::amdgcn_workitem_id_x:
3605+
case Intrinsic::amdgcn_workitem_id_y:
3606+
case Intrinsic::amdgcn_workitem_id_z:
36013607
insert(0);
36023608
break;
36033609
default:

enzyme/Enzyme/Utils.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
#include "llvm/IR/Dominators.h"
4747

4848
#if LLVM_VERSION_MAJOR >= 10
49+
#include "llvm/IR/IntrinsicsAMDGPU.h"
4950
#include "llvm/IR/IntrinsicsNVPTX.h"
5051
#endif
5152

@@ -653,7 +654,8 @@ allUnsyncdPredecessorsOf(llvm::Instruction *inst,
653654
for (auto uinst = inst->getPrevNode(); uinst != nullptr;
654655
uinst = uinst->getPrevNode()) {
655656
if (auto II = llvm::dyn_cast<llvm::IntrinsicInst>(uinst)) {
656-
if (II->getIntrinsicID() == llvm::Intrinsic::nvvm_barrier0) {
657+
if (II->getIntrinsicID() == llvm::Intrinsic::nvvm_barrier0 ||
658+
II->getIntrinsicID() == llvm::Intrinsic::amdgcn_s_barrier) {
657659
return;
658660
}
659661
}
@@ -677,7 +679,8 @@ allUnsyncdPredecessorsOf(llvm::Instruction *inst,
677679
llvm::BasicBlock::reverse_iterator I = BB->rbegin(), E = BB->rend();
678680
for (; I != E; ++I) {
679681
if (auto II = llvm::dyn_cast<llvm::IntrinsicInst>(&*I)) {
680-
if (II->getIntrinsicID() == llvm::Intrinsic::nvvm_barrier0) {
682+
if (II->getIntrinsicID() == llvm::Intrinsic::nvvm_barrier0 ||
683+
II->getIntrinsicID() == llvm::Intrinsic::amdgcn_s_barrier) {
681684
syncd = true;
682685
break;
683686
}

0 commit comments

Comments
 (0)