Skip to content

Commit 88a5efc

Browse files
authored
Preserve NVVM for both intrinsics and library functions (rust-lang#504)
* Preserve NVVM for both intrinsics and library functions * Require preservation of NVVM intrinsics via Clang plugin * Update CMakeLists.txt
1 parent 2e075b7 commit 88a5efc

File tree

8 files changed

+178
-43
lines changed

8 files changed

+178
-43
lines changed

.github/workflows/bcload.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ jobs:
2525
run: |
2626
wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key|sudo apt-key add -
2727
sudo apt-add-repository "deb http://apt.llvm.org/`lsb_release -c | cut -f2`/ llvm-toolchain-`lsb_release -c | cut -f2`-${{ matrix.llvm }} main" || true
28-
sudo apt-get install -y autoconf cmake gcc g++ libtool gfortran llvm-${{ matrix.llvm }}-dev libomp-${{ matrix.llvm }}-dev clang-${{ matrix.llvm }} libeigen3-dev libboost-dev
28+
sudo apt-get install -y autoconf cmake gcc g++ libtool gfortran llvm-${{ matrix.llvm }}-dev libomp-${{ matrix.llvm }}-dev clang-${{ matrix.llvm }} libclang-${{ matrix.llvm }}-dev libeigen3-dev libboost-dev
2929
sudo python3 -m pip install --upgrade pip setuptools
3030
sudo python3 -m pip install lit
3131
sudo touch /usr/lib/llvm-${{ matrix.llvm }}/bin/yaml-bench

.github/workflows/ccpp.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ jobs:
2626
run: |
2727
wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key|sudo apt-key add -
2828
sudo apt-add-repository "deb http://apt.llvm.org/`lsb_release -c | cut -f2`/ llvm-toolchain-`lsb_release -c | cut -f2`-${{ matrix.llvm }} main" || true
29-
sudo apt-get install -y autoconf cmake gcc g++ libtool gfortran llvm-${{ matrix.llvm }}-dev libomp-${{ matrix.llvm }}-dev clang-${{ matrix.llvm }} libeigen3-dev libboost-dev
29+
sudo apt-get install -y autoconf cmake gcc g++ libtool gfortran llvm-${{ matrix.llvm }}-dev libomp-${{ matrix.llvm }}-dev clang-${{ matrix.llvm }} libclang-${{ matrix.llvm }}-dev libeigen3-dev libboost-dev
3030
sudo python3 -m pip install --upgrade pip setuptools
3131
sudo python3 -m pip install lit
3232
sudo touch /usr/lib/llvm-${{ matrix.llvm }}/bin/yaml-bench

enzyme/CMakeLists.txt

+59-4
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,18 @@ get_filename_component(LLVM_ABSOLUTE_DIR
4141

4242
set(LLVM_DIR "${LLVM_ABSOLUTE_DIR}" CACHE FILEPATH "b" FORCE)
4343

44-
message("found llvm dir " ${LLVM_DIR})
44+
if (NOT DEFINED LLVM_EXTERNAL_LIT)
45+
if(LLVM_DIR MATCHES ".*/cmake/llvm/?$")
46+
message("found llvm match ${CMAKE_MATCH_1} dir ${LLVM_DIR}")
47+
if (EXISTS ${LLVM_DIR}/../../../bin/llvm-lit)
48+
set(LLVM_EXTERNAL_LIT ${LLVM_DIR}/../../../bin/llvm-lit)
49+
endif()
50+
else()
51+
if (EXISTS ${LLVM_DIR}/bin/llvm-lit)
52+
set(LLVM_EXTERNAL_LIT ${LLVM_DIR}/bin/llvm-lit)
53+
endif()
54+
endif()
55+
endif()
4556

4657
get_filename_component(LLVM_ABSOLUTE_LIT
4758
"${LLVM_EXTERNAL_LIT}"
@@ -50,15 +61,55 @@ get_filename_component(LLVM_ABSOLUTE_LIT
5061
set(LLVM_EXTERNAL_LIT "${LLVM_ABSOLUTE_LIT}" CACHE FILEPATH "a" FORCE)
5162
message("found llvm lit " ${LLVM_EXTERNAL_LIT})
5263

53-
54-
5564
list(INSERT CMAKE_PREFIX_PATH 0 "${LLVM_DIR}")
65+
if (DEFINED Clang_DIR)
66+
get_filename_component(Clang_ABSOLUTE_DIR
67+
"${Clang_DIR}"
68+
REALPATH BASE_DIR "${CMAKE_BINARY_DIR}")
69+
set(Clang_DIR "${Clang_ABSOLUTE_DIR}" CACHE FILEPATH "b" FORCE)
70+
list(INSERT CMAKE_PREFIX_PATH 0 "${Clang_DIR}")
71+
message("clang dir defined ${Clang_DIR}")
72+
else()
73+
if(LLVM_DIR MATCHES ".*/cmake/llvm/?$")
74+
if (EXISTS ${LLVM_DIR}/../clang/../../libclangBasic.a)
75+
set(Clang_DIR ${LLVM_DIR}/../clang)
76+
list(INSERT CMAKE_PREFIX_PATH 0 "${Clang_DIR}")
77+
endif()
78+
elseif(LLVM_DIR MATCHES ".*/llvm-([0-9]+)/?$")
79+
if (EXISTS ${LLVM_DIR}/lib/libclangBasic.a)
80+
set(Clang_DIR ${LLVM_DIR}/lib/cmake/clang)
81+
list(INSERT CMAKE_PREFIX_PATH 0 "${Clang_DIR}")
82+
endif()
83+
else()
84+
if (EXISTS ${LLVM_DIR}/lib/libclangBasic.a)
85+
set(Clang_DIR ${LLVM_DIR})
86+
list(INSERT CMAKE_PREFIX_PATH 0 "${Clang_DIR}")
87+
endif()
88+
endif()
89+
message("clang dir from llvm ${Clang_DIR}")
90+
endif()
5691
message("CMAKE_PREFIX_PATH " ${CMAKE_PREFIX_PATH})
92+
5793
find_package(LLVM REQUIRED CONFIG)
5894

5995
list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")
6096
include(AddLLVM)
6197

98+
message("clang dir ${Clang_DIR}")
99+
100+
if (DEFINED Clang_DIR)
101+
find_package(Clang REQUIRED CONFIG PATHS ${Clang_DIR} NO_DEFAULT_PATH)
102+
if (${Clang_FOUND})
103+
include_directories(${CLANG_INCLUDE_DIRS})
104+
message("clang inc dir ${CLANG_INCLUDE_DIRS}")
105+
add_definitions(${CLANG_DEFINITIONS})
106+
endif()
107+
else()
108+
set(Clang_FOUND 0)
109+
endif()
110+
message("found ${Clang_FOUND}")
111+
# include(AddClang)
112+
62113
add_definitions(${LLVM_DEFINITIONS})
63114
include_directories(${LLVM_INCLUDE_DIRS})
64115
message("LLVM_INSTALL_PREFIX: ${LLVM_INSTALL_PREFIX}")
@@ -132,9 +183,13 @@ if(NOT IS_ABSOLUTE "${${var}}")
132183
endif()
133184
endforeach()
134185

135-
export(TARGETS LLVMEnzyme-${LLVM_VERSION_MAJOR} ClangEnzyme-${LLVM_VERSION_MAJOR}
186+
export(TARGETS LLVMEnzyme-${LLVM_VERSION_MAJOR}
136187
FILE "${PROJECT_BINARY_DIR}/EnzymeTargets.cmake")
137188

189+
if (${Clang_FOUND})
190+
export(TARGETS ClangEnzyme-${LLVM_VERSION_MAJOR}
191+
APPEND FILE "${PROJECT_BINARY_DIR}/EnzymeTargets.cmake")
192+
endif()
138193
export(PACKAGE Enzyme)
139194

140195
set(CONF_LLVM_VERSION_MAJOR ${LLVM_VERSION_MAJOR})

enzyme/Enzyme/CMakeLists.txt

+17-1
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,15 @@ if (${LLVM_VERSION_MAJOR} LESS 8)
1919
PLUGIN_TOOL
2020
opt
2121
)
22+
if (${Clang_FOUND})
2223
add_llvm_loadable_module( ClangEnzyme-${LLVM_VERSION_MAJOR}
2324
${ENZYME_SRC} Clang/EnzymeClang.cpp
2425
DEPENDS
2526
intrinsics_gen
2627
PLUGIN_TOOL
2728
opt
2829
)
30+
endif()
2931
else()
3032
# on windows `PLUGIN_TOOL` doesn't link against LLVM.dll
3133
if ((WIN32 OR CYGWIN) AND LLVM_LINK_LLVM_DYLIB)
@@ -37,6 +39,7 @@ if ((WIN32 OR CYGWIN) AND LLVM_LINK_LLVM_DYLIB)
3739
LINK_COMPONENTS
3840
LLVM
3941
)
42+
if (${Clang_FOUND})
4043
add_llvm_library( ClangEnzyme-${LLVM_VERSION_MAJOR}
4144
${ENZYME_SRC} Clang/EnzymeClang.cpp
4245
MODULE
@@ -45,6 +48,7 @@ if ((WIN32 OR CYGWIN) AND LLVM_LINK_LLVM_DYLIB)
4548
LINK_COMPONENTS
4649
LLVM
4750
)
51+
endif()
4852
else()
4953
add_llvm_library( LLVMEnzyme-${LLVM_VERSION_MAJOR}
5054
${ENZYME_SRC}
@@ -54,6 +58,7 @@ else()
5458
PLUGIN_TOOL
5559
opt
5660
)
61+
if (${Clang_FOUND})
5762
add_llvm_library( ClangEnzyme-${LLVM_VERSION_MAJOR}
5863
${ENZYME_SRC} Clang/EnzymeClang.cpp
5964
MODULE
@@ -64,6 +69,7 @@ else()
6469
)
6570
endif()
6671
endif()
72+
endif()
6773

6874
if (${ENZYME_EXTERNAL_SHARED_LIB})
6975
add_library( Enzyme-${LLVM_VERSION_MAJOR}
@@ -82,12 +88,22 @@ if (APPLE)
8288
# Darwin-specific linker flags for loadable modules.
8389
set_target_properties(LLVMEnzyme-${LLVM_VERSION_MAJOR} PROPERTIES
8490
LINK_FLAGS "-Wl,-flat_namespace -Wl,-undefined -Wl,suppress")
91+
if (${Clang_FOUND})
8592
set_target_properties(ClangEnzyme-${LLVM_VERSION_MAJOR} PROPERTIES
8693
LINK_FLAGS "-Wl,-flat_namespace -Wl,-undefined -Wl,suppress")
8794
endif()
95+
endif()
96+
97+
install(TARGETS LLVMEnzyme-${LLVM_VERSION_MAJOR}
98+
EXPORT EnzymeTargets
99+
LIBRARY DESTINATION lib COMPONENT shlib
100+
PUBLIC_HEADER DESTINATION "${INSTALL_INCLUDE_DIR}/Enzyme"
101+
COMPONENT dev)
88102

89-
install(TARGETS LLVMEnzyme-${LLVM_VERSION_MAJOR} ClangEnzyme-${LLVM_VERSION_MAJOR}
103+
if (${Clang_FOUND})
104+
install(TARGETS ClangEnzyme-${LLVM_VERSION_MAJOR}
90105
EXPORT EnzymeTargets
91106
LIBRARY DESTINATION lib COMPONENT shlib
92107
PUBLIC_HEADER DESTINATION "${INSTALL_INCLUDE_DIR}/Enzyme"
93108
COMPONENT dev)
109+
endif()

enzyme/Enzyme/Clang/EnzymeClang.cpp

+59
Original file line numberDiff line numberDiff line change
@@ -83,4 +83,63 @@ static void loadLTOPass(const PassManagerBuilder &Builder,
8383
static RegisterStandardPasses
8484
clangtoolLoader_LTO(PassManagerBuilder::EP_FullLinkTimeOptimizationEarly,
8585
loadLTOPass);
86+
87+
#include "clang/AST/Attr.h"
88+
#include "clang/AST/DeclGroup.h"
89+
#include "clang/Frontend/CompilerInstance.h"
90+
#include "clang/Frontend/FrontendAction.h"
91+
#include "clang/Frontend/FrontendPluginRegistry.h"
92+
93+
template <typename ConsumerType>
94+
class EnzymeAction : public clang::PluginASTAction {
95+
protected:
96+
std::unique_ptr<clang::ASTConsumer>
97+
CreateASTConsumer(clang::CompilerInstance &CI, llvm::StringRef InFile) {
98+
return std::unique_ptr<clang::ASTConsumer>(new ConsumerType(CI));
99+
}
100+
101+
bool ParseArgs(const clang::CompilerInstance &CI,
102+
const std::vector<std::string> &args) {
103+
return true;
104+
}
105+
106+
PluginASTAction::ActionType getActionType() override {
107+
return AddBeforeMainAction;
108+
}
109+
};
110+
111+
class EnzymePlugin : public clang::ASTConsumer {
112+
clang::CompilerInstance &CI;
113+
114+
public:
115+
EnzymePlugin(clang::CompilerInstance &CI) : CI(CI) {}
116+
~EnzymePlugin() {}
117+
bool HandleTopLevelDecl(clang::DeclGroupRef dg) override {
118+
using namespace clang;
119+
DeclGroupRef::iterator it;
120+
121+
// Forcibly require emission of all libdevice
122+
for (it = dg.begin(); it != dg.end(); ++it) {
123+
auto FD = dyn_cast<FunctionDecl>(*it);
124+
if (!FD)
125+
continue;
126+
127+
if (!FD->hasAttr<clang::CUDADeviceAttr>())
128+
continue;
129+
130+
if (!FD->getIdentifier())
131+
continue;
132+
if (!StringRef(FD->getLocation().printToString(CI.getSourceManager()))
133+
.contains("/__clang_cuda_math.h"))
134+
continue;
135+
136+
FD->addAttr(UsedAttr::CreateImplicit(CI.getASTContext()));
137+
}
138+
return true;
139+
}
140+
};
141+
142+
// register the PluginASTAction in the registry.
143+
static clang::FrontendPluginRegistry::Add<EnzymeAction<EnzymePlugin>>
144+
X("enzyme", "Enzyme Plugin");
86145
#endif

enzyme/Enzyme/FunctionUtils.cpp

+31-27
Original file line numberDiff line numberDiff line change
@@ -2084,36 +2084,40 @@ void PreProcessCache::optimizeIntermediate(Function *F) {
20842084
// EarlyCSEPass(/*memoryssa*/ true).run(*F, FAM);
20852085

20862086
for (Function &Impl : *F->getParent()) {
2087-
if (!Impl.hasFnAttribute("implements"))
2088-
continue;
2089-
const Attribute &A = Impl.getFnAttribute("implements");
2090-
2091-
const StringRef SpecificationName = A.getValueAsString();
2092-
Function *Specification = F->getParent()->getFunction(SpecificationName);
2093-
if (!Specification) {
2094-
LLVM_DEBUG(dbgs() << "Found implementation '" << Impl.getName()
2095-
<< "' but no matching specification with name '"
2096-
<< SpecificationName
2097-
<< "', potentially inlined and/or eliminated.\n");
2098-
continue;
2099-
}
2100-
LLVM_DEBUG(dbgs() << "Replace specification '" << Specification->getName()
2101-
<< "' with implementation '" << Impl.getName() << "'\n");
2102-
2103-
for (auto I = Specification->use_begin(), UE = Specification->use_end();
2104-
I != UE;) {
2105-
auto &use = *I;
2106-
++I;
2107-
auto cext = ConstantExpr::getBitCast(&Impl, Specification->getType());
2108-
use.set(cext);
2109-
if (auto CI = dyn_cast<CallInst>(use.getUser())) {
2087+
for (auto attr : {"implements", "implements2"}) {
2088+
if (!Impl.hasFnAttribute(attr))
2089+
continue;
2090+
const Attribute &A = Impl.getFnAttribute(attr);
2091+
2092+
const StringRef SpecificationName = A.getValueAsString();
2093+
Function *Specification = F->getParent()->getFunction(SpecificationName);
2094+
if (!Specification) {
2095+
LLVM_DEBUG(dbgs() << "Found implementation '" << Impl.getName()
2096+
<< "' but no matching specification with name '"
2097+
<< SpecificationName
2098+
<< "', potentially inlined and/or eliminated.\n");
2099+
continue;
2100+
}
2101+
LLVM_DEBUG(dbgs() << "Replace specification '" << Specification->getName()
2102+
<< "' with implementation '" << Impl.getName()
2103+
<< "'\n");
2104+
2105+
for (auto I = Specification->use_begin(), UE = Specification->use_end();
2106+
I != UE;) {
2107+
auto &use = *I;
2108+
++I;
2109+
auto cext = ConstantExpr::getBitCast(&Impl, Specification->getType());
2110+
use.set(cext);
2111+
if (auto CI = dyn_cast<CallInst>(use.getUser())) {
21102112
#if LLVM_VERSION_MAJOR >= 11
2111-
if (CI->getCalledOperand() == cext || CI->getCalledFunction() == &Impl)
2113+
if (CI->getCalledOperand() == cext ||
2114+
CI->getCalledFunction() == &Impl)
21122115
#else
2113-
if (CI->getCalledValue() == cext || CI->getCalledFunction() == &Impl)
2116+
if (CI->getCalledValue() == cext || CI->getCalledFunction() == &Impl)
21142117
#endif
2115-
{
2116-
CI->setCallingConv(Impl.getCallingConv());
2118+
{
2119+
CI->setCallingConv(Impl.getCallingConv());
2120+
}
21172121
}
21182122
}
21192123
}

enzyme/Enzyme/PreserveNVVM.cpp

+8-9
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,8 @@ class PreserveNVVM : public FunctionPass {
5555

5656
bool runOnFunction(Function &F) override {
5757
bool changed = false;
58-
std::map<std::pair<Type *, std::string>,
59-
std::pair<std::string, std::string>>
60-
Implements;
61-
for (Type *T : {Type::getFloatTy(F.getContext()),
62-
Type::getDoubleTy(F.getContext())}) {
58+
std::map<std::string, std::pair<std::string, std::string>> Implements;
59+
for (std::string T : {"", "f"}) {
6360
// sincos, sinpi, cospi, sincospi, cyl_bessel_i1
6461
for (std::string name :
6562
{"sin", "cos", "tan", "log2", "exp", "exp2",
@@ -75,19 +72,20 @@ class PreserveNVVM : public FunctionPass {
7572
"sqrt"}) {
7673
std::string nvname = "__nv_" + name;
7774
std::string llname = "llvm." + name + ".";
75+
std::string mathname = name;
7876

79-
if (T->isFloatTy()) {
77+
if (T == "f") {
78+
mathname += "f";
8079
nvname += "f";
8180
llname += "f32";
8281
} else {
8382
llname += "f64";
8483
}
8584

86-
Implements[std::make_pair(T, nvname)] = std::make_pair(name, llname);
85+
Implements[nvname] = std::make_pair(mathname, llname);
8786
}
8887
}
89-
auto idx = std::make_pair(F.getReturnType(), F.getName().str());
90-
auto found = Implements.find(idx);
88+
auto found = Implements.find(F.getName().str());
9189
if (found != Implements.end()) {
9290
if (Begin) {
9391
F.removeFnAttr(Attribute::AlwaysInline);
@@ -96,6 +94,7 @@ class PreserveNVVM : public FunctionPass {
9694
// cannot be erased.
9795
F.setLinkage(Function::LinkageTypes::ExternalLinkage);
9896
F.addFnAttr("implements", found->second.second);
97+
F.addFnAttr("implements2", found->second.first);
9998
F.addFnAttr("enzyme_math", found->second.first);
10099
changed = true;
101100
} else {

enzyme/test/CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,7 @@ set(ENZYME_TEST_DEPS LLVMEnzyme-${LLVM_VERSION_MAJOR} BCPass-${LLVM_VERSION_MAJO
1010
add_subdirectory(ActivityAnalysis)
1111
add_subdirectory(TypeAnalysis)
1212
add_subdirectory(Enzyme)
13+
if (${Clang_FOUND})
1314
add_subdirectory(Integration)
15+
endif()
1416
add_subdirectory(BCLoader)

0 commit comments

Comments
 (0)