Skip to content

[DAE][SYCL] Enable DAE in SYCL kernel functions #2226

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Aug 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llvm/include/llvm/InitializePasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ void initializeCostModelAnalysisPass(PassRegistry&);
void initializeCrossDSOCFIPass(PassRegistry&);
void initializeDAEPass(PassRegistry&);
void initializeDAHPass(PassRegistry&);
void initializeDAESYCLPass(PassRegistry&);
void initializeDCELegacyPassPass(PassRegistry&);
void initializeDSELegacyPassPass(PassRegistry&);
void initializeDataFlowSanitizerPass(PassRegistry&);
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/LinkAllPasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ namespace {
(void) llvm::createControlHeightReductionLegacyPass();
(void) llvm::createCostModelAnalysisPass();
(void) llvm::createDeadArgEliminationPass();
(void) llvm::createDeadArgEliminationSYCLPass();
(void) llvm::createDeadCodeEliminationPass();
(void) llvm::createDeadInstEliminationPass();
(void) llvm::createDeadStoreEliminationPass();
Expand Down
3 changes: 3 additions & 0 deletions llvm/include/llvm/Transforms/IPO.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,9 @@ ModulePass *createDeadArgEliminationPass();
/// bugpoint.
ModulePass *createDeadArgHackingPass();

/// DeadArgumentElimination pass for SYCL kernel functions
ModulePass *createDeadArgEliminationSYCLPass();

//===----------------------------------------------------------------------===//
/// createArgumentPromotionPass - This pass promotes "by reference" arguments to
/// be passed by value if the number of elements passed is smaller or
Expand Down
10 changes: 8 additions & 2 deletions llvm/include/llvm/Transforms/IPO/DeadArgumentElimination.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,10 @@ class DeadArgumentEliminationPass
/// thus become dead in the end.
enum Liveness { Live, MaybeLive };

DeadArgumentEliminationPass(bool ShouldHackArguments_ = false)
: ShouldHackArguments(ShouldHackArguments_) {}
DeadArgumentEliminationPass(bool ShouldHackArguments_ = false,
bool CheckSpirKernels_ = false)
: ShouldHackArguments(ShouldHackArguments_),
CheckSpirKernels(CheckSpirKernels_) {}

PreservedAnalyses run(Module &M, ModuleAnalysisManager &);

Expand Down Expand Up @@ -121,6 +123,10 @@ class DeadArgumentEliminationPass
/// (used only by bugpoint).
bool ShouldHackArguments = false;

/// This allows to eliminate dead arguments in SPIR kernel functions with
/// external linkage in SYCL environment
bool CheckSpirKernels = false;

private:
Liveness MarkIfNotLive(RetOrArg Use, UseVector &MaybeLiveUses);
Liveness SurveyUse(const Use *U, UseVector &MaybeLiveUses,
Expand Down
130 changes: 128 additions & 2 deletions llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/IPO.h"
Expand All @@ -54,6 +55,11 @@ using namespace llvm;

#define DEBUG_TYPE "deadargelim"

static cl::opt<std::string>
IntegrationHeaderFileName("integr-header-file",
cl::desc("Path to integration header file"),
cl::value_desc("filename"), cl::Hidden);

STATISTIC(NumArgumentsEliminated, "Number of unread args removed");
STATISTIC(NumRetValsEliminated , "Number of unused return values removed");
STATISTIC(NumArgumentsReplacedWithUndef,
Expand All @@ -77,13 +83,15 @@ namespace {
bool runOnModule(Module &M) override {
if (skipModule(M))
return false;
DeadArgumentEliminationPass DAEP(ShouldHackArguments());
DeadArgumentEliminationPass DAEP(ShouldHackArguments(),
CheckSpirKernels());
ModuleAnalysisManager DummyMAM;
PreservedAnalyses PA = DAEP.run(M, DummyMAM);
return !PA.areAllPreserved();
}

virtual bool ShouldHackArguments() const { return false; }
virtual bool CheckSpirKernels() const { return false; }
};

} // end anonymous namespace
Expand All @@ -103,6 +111,7 @@ namespace {
DAH() : DAE(ID) {}

bool ShouldHackArguments() const override { return true; }
bool CheckSpirKernels() const override { return false; }
};

} // end anonymous namespace
Expand All @@ -113,12 +122,42 @@ INITIALIZE_PASS(DAH, "deadarghaX0r",
"Dead Argument Hacking (BUGPOINT USE ONLY; DO NOT USE)",
false, false)

namespace {

/// DAESYCL - DeadArgumentElimination pass for SPIR kernel functions even
/// if they are external.
struct DAESYCL : public DAE {
static char ID;

DAESYCL() : DAE(ID) {
initializeDAESYCLPass(*PassRegistry::getPassRegistry());
}

StringRef getPassName() const override {
return "Dead Argument Elimination for SPIR kernels in SYCL environment";
}

bool ShouldHackArguments() const override { return false; }
bool CheckSpirKernels() const override { return true; }
};

} // end anonymous namespace

char DAESYCL::ID = 0;

INITIALIZE_PASS(
DAESYCL, "deadargelim-sycl",
"Dead Argument Elimination for SPIR kernels in SYCL environment", false,
false)

/// createDeadArgEliminationPass - This pass removes arguments from functions
/// which are not used by the body of the function.
ModulePass *llvm::createDeadArgEliminationPass() { return new DAE(); }

ModulePass *llvm::createDeadArgHackingPass() { return new DAH(); }

ModulePass *llvm::createDeadArgEliminationSYCLPass() { return new DAESYCL(); }

/// DeleteDeadVarargs - If this is an function that takes a ... list, and if
/// llvm.vastart is never called, the varargs list is dead for the function.
bool DeadArgumentEliminationPass::DeleteDeadVarargs(Function &Fn) {
Expand Down Expand Up @@ -535,7 +574,14 @@ void DeadArgumentEliminationPass::SurveyFunction(const Function &F) {
<< " has musttail calls\n");
}

if (!F.hasLocalLinkage() && (!ShouldHackArguments || F.isIntrinsic())) {
// We can't modify arguments if the function is not local
// but we can do so for SPIR kernel function in SYCL environment.
bool FuncIsSpirKernel =
CheckSpirKernels &&
StringRef(F.getParent()->getTargetTriple()).contains("sycldevice") &&
F.getCallingConv() == CallingConv::SPIR_KERNEL;
bool FuncIsLive = !F.hasLocalLinkage() && !FuncIsSpirKernel;
if (FuncIsLive && (!ShouldHackArguments || F.isIntrinsic())) {
MarkLive(F);
return;
}
Expand Down Expand Up @@ -714,6 +760,78 @@ void DeadArgumentEliminationPass::PropagateLiveness(const RetOrArg &RA) {
Uses.erase(Begin, I);
}

// Update kernel arguments table inside the integration header.
// For example:
// static constexpr const bool param_omit_table[] = {
// // OMIT_TABLE_BEGIN
// // kernel_name_1
// false, false, // <= update to true if the argument is dead
// // kernel_name_2
// false, false,
// // OMIT_TABLE_END
// };
// TODO: batch changes to multiple SPIR kernels and do one bulk update.
constexpr StringLiteral OMIT_TABLE_BEGIN("// OMIT_TABLE_BEGIN");
constexpr StringLiteral OMIT_TABLE_END("// OMIT_TABLE_END");
static void updateIntegrationHeader(StringRef SpirKernelName,
const ArrayRef<bool> &ArgAlive) {
ErrorOr<std::unique_ptr<MemoryBuffer>> IntHeaderBuffer =
MemoryBuffer::getFile(IntegrationHeaderFileName);

if (!IntHeaderBuffer)
report_fatal_error("unable to read integration header file '" +
IntegrationHeaderFileName +
"': " + IntHeaderBuffer.getError().message());
Comment on lines +781 to +784
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you going to replace this with assert(s)?
Another option to use LLVM_DEBUG and just skip the optimization.
Hard fail like this probably not the best approach for an optimization.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is a hard fail. If we did not update int header it will result in a later runtime fail.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we did not update int header it will result in a later runtime fail.

If we keep all arguments, it should be okay. I mean if anything is wrong with pre-requisites, this pass can just do nothing instead of crashing and it won't break anything.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Introduced an early exit if IntegrationHeaderFileName is not provided.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great. I think we can replace all error checking with asserts in this file.
This way we can "hard fail" on build with enabled assertions and there is no overhead on internal consistency checking on the builds w/o assertions.

if (!<cond>)
  report_fatal_error(<msg>);

->

assert(<cond> & <msg>);

I don't think we should do thorough runtime checking for integration header format. This file is auto-generated by the compiler, so it's should be validated in scope of #2236. We can validate that file format is correct with asserts.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feel free to address this in a separate PR.


// 1. Find the region between OMIT_TABLE_BEGIN and OMIT_TABLE_END
StringRef IntHeader((*IntHeaderBuffer)->getBuffer());
if (!IntHeader.contains(OMIT_TABLE_BEGIN))
report_fatal_error(OMIT_TABLE_BEGIN +
" marker not found in integration header");
if (!IntHeader.contains(OMIT_TABLE_END))
report_fatal_error(OMIT_TABLE_END +
" marker not found in integration header");

size_t BeginRegionPos =
IntHeader.find(OMIT_TABLE_BEGIN) + OMIT_TABLE_BEGIN.size();
size_t EndRegionPos = IntHeader.find(OMIT_TABLE_END);

StringRef OmitArgTable = IntHeader.slice(BeginRegionPos, EndRegionPos);

// 2. Find the line that corresponds to the SPIR kernel
if (!OmitArgTable.contains(SpirKernelName))
report_fatal_error(
"Argument table not found in integration header for function '" +
SpirKernelName + "'");

size_t BeginLinePos =
OmitArgTable.find(SpirKernelName) + SpirKernelName.size();
size_t EndLinePos = OmitArgTable.find("//", BeginLinePos);

StringRef OmitArgLine = OmitArgTable.slice(BeginLinePos, EndLinePos);

size_t LineLeftTrim = OmitArgLine.size() - OmitArgLine.ltrim().size();
size_t LineRightTrim = OmitArgLine.size() - OmitArgLine.rtrim().size();

// 3. Construct new file contents and replace only that string.
std::string NewIntHeader;
NewIntHeader +=
IntHeader.take_front(BeginRegionPos + BeginLinePos + LineLeftTrim);
for (auto &AliveArg : ArgAlive)
NewIntHeader += AliveArg ? "false, " : "true, ";
NewIntHeader += IntHeader.drop_front(BeginRegionPos + BeginLinePos +
OmitArgLine.size() - LineRightTrim);

// 4. Flush the string into the file.
std::error_code EC;
raw_fd_ostream File(IntegrationHeaderFileName, EC, sys::fs::F_Text);

if (EC)
report_fatal_error("Cannot open integration header for writing.");

File << NewIntHeader;
}

// RemoveDeadStuffFromFunction - Remove any arguments and return values from F
// that are not in LiveValues. Transform the function and all of the callees of
// the function to not have these arguments and return values.
Expand Down Expand Up @@ -757,6 +875,9 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) {
}
}

if (CheckSpirKernels)
updateIntegrationHeader(F->getName(), ArgAlive);

// Find out the new return value.
Type *RetTy = FTy->getReturnType();
Type *NRetTy = nullptr;
Expand Down Expand Up @@ -1072,6 +1193,11 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) {

PreservedAnalyses DeadArgumentEliminationPass::run(Module &M,
ModuleAnalysisManager &) {
// Integration header file must be provided for
// DAE to work on SPIR kernels.
if (CheckSpirKernels && !IntegrationHeaderFileName.getNumOccurrences())
return PreservedAnalyses::all();

bool Changed = false;

// First pass: Do a simple check to see if any functions can have their "..."
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Transforms/IPO/IPO.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ void llvm::initializeIPO(PassRegistry &Registry) {
initializeCrossDSOCFIPass(Registry);
initializeDAEPass(Registry);
initializeDAHPass(Registry);
initializeDAESYCLPass(Registry);
initializeForceFunctionAttrsLegacyPassPass(Registry);
initializeGlobalDCELegacyPassPass(Registry);
initializeGlobalOptLegacyPassPass(Registry);
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Transforms/IPO/PassManagerBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,7 @@ void PassManagerBuilder::populateModulePassManager(
if (RunInliner) {
MPM.add(createGlobalOptimizerPass());
MPM.add(createGlobalDCEPass());
MPM.add(createDeadArgEliminationSYCLPass());
}

// If we are planning to perform ThinLTO later, let's not bloat the code with
Expand Down
1 change: 1 addition & 0 deletions llvm/test/Other/opt-O2-pipeline.ll
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@
; CHECK-NEXT: Branch Probability Analysis
; CHECK-NEXT: Block Frequency Analysis
; CHECK-NEXT: Dead Global Elimination
; CHECK-NEXT: Dead Argument Elimination for SPIR kernels in SYCL environment
; CHECK-NEXT: CallGraph Construction
; CHECK-NEXT: Globals Alias Analysis
; CHECK-NEXT: FunctionPass Manager
Expand Down
1 change: 1 addition & 0 deletions llvm/test/Other/opt-O3-pipeline-enable-matrix.ll
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@
; CHECK-NEXT: Branch Probability Analysis
; CHECK-NEXT: Block Frequency Analysis
; CHECK-NEXT: Dead Global Elimination
; CHECK-NEXT: Dead Argument Elimination for SPIR kernels in SYCL environment
; CHECK-NEXT: CallGraph Construction
; CHECK-NEXT: Globals Alias Analysis
; CHECK-NEXT: FunctionPass Manager
Expand Down
1 change: 1 addition & 0 deletions llvm/test/Other/opt-O3-pipeline.ll
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@
; CHECK-NEXT: Branch Probability Analysis
; CHECK-NEXT: Block Frequency Analysis
; CHECK-NEXT: Dead Global Elimination
; CHECK-NEXT: Dead Argument Elimination for SPIR kernels in SYCL environment
; CHECK-NEXT: CallGraph Construction
; CHECK-NEXT: Globals Alias Analysis
; CHECK-NEXT: FunctionPass Manager
Expand Down
1 change: 1 addition & 0 deletions llvm/test/Other/opt-Os-pipeline.ll
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@
; CHECK-NEXT: Branch Probability Analysis
; CHECK-NEXT: Block Frequency Analysis
; CHECK-NEXT: Dead Global Elimination
; CHECK-NEXT: Dead Argument Elimination for SPIR kernels in SYCL environment
; CHECK-NEXT: CallGraph Construction
; CHECK-NEXT: Globals Alias Analysis
; CHECK-NEXT: FunctionPass Manager
Expand Down
17 changes: 17 additions & 0 deletions llvm/test/Transforms/DeadArgElim/sycl-kernels-neg1.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --function-signature
; RUN: opt < %s -deadargelim-sycl -S | FileCheck %s
; Skip DAE if the path to the integration header is not specified.

target triple = "spir64-unknown-unknown-sycldevice"

define weak_odr spir_kernel void @NegativeSpirKernel(float %arg1, float %arg2) {
; CHECK-LABEL: define {{[^@]+}}@NegativeSpirKernel
; CHECK-SAME: (float [[ARG1:%.*]], float [[ARG2:%.*]])
; CHECK-NEXT: call void @foo(float [[ARG1]])
; CHECK-NEXT: ret void
;
call void @foo(float %arg1)
ret void
}

declare void @foo(float %arg)
13 changes: 13 additions & 0 deletions llvm/test/Transforms/DeadArgElim/sycl-kernels-neg2.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
; RUN: touch %t-int_header.h
; RUN: not --crash opt < %s -deadargelim-sycl -S -integr-header-file %t-bad_file.h

; Path to the integration header is wrong.

target triple = "spir64-unknown-unknown-sycldevice"

define weak_odr spir_kernel void @NegativeSpirKernel(float %arg1, float %arg2) {
call void @foo(float %arg1)
ret void
}

declare void @foo(float %arg)
16 changes: 16 additions & 0 deletions llvm/test/Transforms/DeadArgElim/sycl-kernels-neg3.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
; RUN: echo 'static constexpr const bool param_omit_table[] = {' > %t-int_header.h
; RUN: echo ' // NegativeSpirKernel' >> %t-int_header.h
; RUN: echo ' false, false,' >> %t-int_header.h
; RUN: echo '};' >> %t-int_header.h
; RUN: not --crash opt < %s -deadargelim-sycl -S -integr-header-file %t-int_header.h

; No OMIT_TABLE markers in the integration header.

target triple = "spir64-unknown-unknown-sycldevice"

define weak_odr spir_kernel void @NegativeSpirKernel(float %arg1, float %arg2) {
call void @foo(float %arg1)
ret void
}

declare void @foo(float %arg)
18 changes: 18 additions & 0 deletions llvm/test/Transforms/DeadArgElim/sycl-kernels-neg4.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
; RUN: echo 'static constexpr const bool param_omit_table[] = {' > %t-int_header.h
; RUN: echo ' // OMIT_TABLE_BEGIN' >> %t-int_header.h
; RUN: echo ' // WrongKernelName' >> %t-int_header.h
; RUN: echo ' false, false,' >> %t-int_header.h
; RUN: echo ' // OMIT_TABLE_END' >> %t-int_header.h
; RUN: echo '};' >> %t-int_header.h
; RUN: not --crash opt < %s -deadargelim-sycl -S -integr-header-file %t-int_header.h

; Wrong kernel name in the integration header.

target triple = "spir64-unknown-unknown-sycldevice"

define weak_odr spir_kernel void @NegativeSpirKernel(float %arg1, float %arg2) {
call void @foo(float %arg1)
ret void
}

declare void @foo(float %arg)
20 changes: 20 additions & 0 deletions llvm/test/Transforms/DeadArgElim/sycl-kernels-neg5.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --function-signature
; RUN: opt < %s -deadargelim -S | FileCheck %s
; RUN: opt < %s -deadargelim-sycl -S | FileCheck %s

; This test ensures dead arguments are not eliminated
; from a global function that is not a SYCL kernel.

target triple = "spir64-unknown-unknown-sycldevice"

define weak_odr void @NotASpirKernel(float %arg1, float %arg2) {
; CHECK-LABEL: define {{[^@]+}}@NotASpirKernel
; CHECK-SAME: (float [[ARG1:%.*]], float [[ARG2:%.*]])
; CHECK-NEXT: call void @foo(float [[ARG1]])
; CHECK-NEXT: ret void
;
call void @foo(float %arg1)
ret void
}

declare void @foo(float %arg)
Loading