From ae9090aa72f8511bbecbd1f3690ef6e7f452e864 Mon Sep 17 00:00:00 2001 From: Farzon Lotfi Date: Mon, 23 Sep 2024 22:36:12 -0400 Subject: [PATCH 1/5] [DirectX] Data Scalarization --- llvm/lib/Target/DirectX/CMakeLists.txt | 1 + .../Target/DirectX/DXILDataScalarization.cpp | 312 ++++++++++++++++++ .../Target/DirectX/DXILDataScalarization.h | 35 ++ llvm/lib/Target/DirectX/DirectX.h | 6 + .../Target/DirectX/DirectXTargetMachine.cpp | 2 + llvm/test/CodeGen/DirectX/scalar-load.ll | 40 +++ llvm/test/CodeGen/DirectX/scalar-store.ll | 34 +- 7 files changed, 418 insertions(+), 12 deletions(-) create mode 100644 llvm/lib/Target/DirectX/DXILDataScalarization.cpp create mode 100644 llvm/lib/Target/DirectX/DXILDataScalarization.h create mode 100644 llvm/test/CodeGen/DirectX/scalar-load.ll diff --git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt index 7e0f8a145505e..c8ef0ef6f7e70 100644 --- a/llvm/lib/Target/DirectX/CMakeLists.txt +++ b/llvm/lib/Target/DirectX/CMakeLists.txt @@ -20,6 +20,7 @@ add_llvm_target(DirectXCodeGen DirectXTargetMachine.cpp DirectXTargetTransformInfo.cpp DXContainerGlobals.cpp + DXILDataScalarization.cpp DXILFinalizeLinkage.cpp DXILIntrinsicExpansion.cpp DXILOpBuilder.cpp diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp new file mode 100644 index 0000000000000..e689c0b06fccf --- /dev/null +++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp @@ -0,0 +1,312 @@ +//===- DXILDataScalarization.cpp - Prepare LLVM Module for DXIL Data +//Legalization----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===--------------------------------------------------------------------------------===// + +#include "DXILDataScalarization.h" +#include "DirectX.h" +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstVisitor.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/ReplaceConstant.h" +#include "llvm/IR/Type.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/Local.h" +#include + +#define DEBUG_TYPE "dxil-data-scalarization" +#define Max_VEC_SIZE 4 + +using namespace llvm; + +static void findAndReplaceVectors(Module &M); + +class DataScalarizerVisitor : public InstVisitor { +public: + DataScalarizerVisitor() : GlobalMap() {} + bool visit(Function &F); + // InstVisitor methods. They return true if the instruction was scalarized, + // false if nothing changed. + bool visitInstruction(Instruction &I) { return false; } + bool visitSelectInst(SelectInst &SI) { return false; } + bool visitICmpInst(ICmpInst &ICI) { return false; } + bool visitFCmpInst(FCmpInst &FCI) { return false; } + bool visitUnaryOperator(UnaryOperator &UO) { return false; } + bool visitBinaryOperator(BinaryOperator &BO) { return false; } + bool visitGetElementPtrInst(GetElementPtrInst &GEPI); + bool visitCastInst(CastInst &CI) { return false; } + bool visitBitCastInst(BitCastInst &BCI) { return false; } + bool visitInsertElementInst(InsertElementInst &IEI) { return false; } + bool visitExtractElementInst(ExtractElementInst &EEI) { return false; } + bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; } + bool visitPHINode(PHINode &PHI) { return false; } + bool visitLoadInst(LoadInst &LI); + bool visitStoreInst(StoreInst &SI); + bool visitCallInst(CallInst &ICI) { return false; } + bool visitFreezeInst(FreezeInst &FI) { return false; } + friend void findAndReplaceVectors(llvm::Module &M); + +private: + GlobalVariable *getNewGlobalIfExists(Value *CurrOperand); + DenseMap GlobalMap; + SmallVector PotentiallyDeadInstrs; + bool finish(); +}; + +bool DataScalarizerVisitor::visit(Function &F) { + assert(!GlobalMap.empty()); + ReversePostOrderTraversal RPOT(&F.getEntryBlock()); + for (BasicBlock *BB : RPOT) { + for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) { + Instruction *I = &*II; + bool Done = InstVisitor::visit(I); + ++II; + if (Done && I->getType()->isVoidTy()) + I->eraseFromParent(); + } + } + return finish(); +} + +bool DataScalarizerVisitor::finish() { + // TODO this should do cleanup + RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs); + return true; +} + +GlobalVariable * +DataScalarizerVisitor::getNewGlobalIfExists(Value *CurrOperand) { + if (GlobalVariable *OldGlobal = dyn_cast(CurrOperand)) { + auto It = GlobalMap.find(OldGlobal); + if (It != GlobalMap.end()) { + return It->second; // Found, return the new global + } + } + return nullptr; // Not found +} + +bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) { + for (unsigned I = 0; I < LI.getNumOperands(); ++I) { + Value *CurrOpperand = LI.getOperand(I); + GlobalVariable *NewGlobal = getNewGlobalIfExists(CurrOpperand); + if (NewGlobal) + LI.setOperand(I, NewGlobal); + } + return false; +} + +bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) { + bool ReplaceStore = false; + for (unsigned I = 0; I < SI.getNumOperands(); ++I) { + Value *CurrOpperand = SI.getOperand(I); + GlobalVariable *NewGlobal = getNewGlobalIfExists(CurrOpperand); + if (NewGlobal) { + SI.setOperand(I, NewGlobal); + /*Value *StoredValue = SI.getValueOperand(); + Type *StoredType = StoredValue->getType(); + if (VectorType *VecTy = dyn_cast(StoredType)) { + unsigned NumElements = cast(VecTy)->getNumElements(); + ArrayType *ArrayTy = ArrayType::get(VecTy->getElementType(), + NumElements); std::vector ConstElements; if (ConstantVector + *ConstVec = dyn_cast(StoredValue)) { for(uint I = 0; I < + NumElements; I++) { ConstElements.push_back(ConstVec->getOperand(I)); + } + } + Value *ArrayValue = ConstantArray::get(ArrayTy,ConstElements); + IRBuilder<> Builder(&SI); + Builder.CreateStore(ArrayValue, SI.getPointerOperand()); + replaceStore = true; + }*/ + } + } + return ReplaceStore; +} + +bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) { + for (unsigned I = 0; I < GEPI.getNumOperands(); ++I) { + Value *CurrOpperand = GEPI.getOperand(I); + GlobalVariable *NewGlobal = getNewGlobalIfExists(CurrOpperand); + if (NewGlobal) { + // Prepare to create a new GEP for the new global + IRBuilder<> Builder(&GEPI); // Create an IRBuilder at the position of GEPI + + SmallVector Indices; + for (auto &Index : GEPI.indices()) + Indices.push_back(Index); + + // Create a new GEP for the new global variable + Value *NewGEP = + Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices); + + // Replace the old GEP with the new one + GEPI.replaceAllUsesWith(NewGEP); + PotentiallyDeadInstrs.emplace_back(&GEPI); + } + } + return true; +} + +static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) { + if (auto *VecTy = dyn_cast(T)) + return ArrayType::get(VecTy->getElementType(), + cast(VecTy)->getNumElements()); + if (auto *ArrayTy = dyn_cast(T)) { + Type *NewElementType = + replaceVectorWithArray(ArrayTy->getElementType(), Ctx); + return ArrayType::get(NewElementType, ArrayTy->getNumElements()); + } + // If it's not a vector or array, return the original type. + return T; +} + +Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType, + LLVMContext &Ctx) { + // Handle ConstantAggregateZero (zero-initialized constants) + if (isa(Init)) { + return ConstantAggregateZero::get(NewType); + } + + // Handle UndefValue (undefined constants) + if (isa(Init)) { + return UndefValue::get(NewType); + } + + // Handle vector to array transformation + if (isa(OrigType) && isa(NewType)) { + // Convert vector initializer to array initializer + auto *VecInit = dyn_cast(Init); + if (!VecInit) { + llvm_unreachable("Expected a ConstantVector for vector initializer!"); + } + + SmallVector ArrayElements; + for (unsigned I = 0; I < VecInit->getNumOperands(); ++I) { + ArrayElements.push_back(VecInit->getOperand(I)); + } + + return ConstantArray::get(cast(NewType), ArrayElements); + } + + // Handle array of vectors transformation + if (auto *ArrayTy = dyn_cast(OrigType)) { + // Recursively transform array elements + auto *ArrayInit = dyn_cast(Init); + if (!ArrayInit) { + llvm_unreachable("Expected a ConstantArray for array initializer!"); + } + + SmallVector NewArrayElements; + for (unsigned I = 0; I < ArrayTy->getNumElements(); ++I) { + Constant *NewElemInit = transformInitializer( + ArrayInit->getOperand(I), ArrayTy->getElementType(), + cast(NewType)->getElementType(), Ctx); + NewArrayElements.push_back(NewElemInit); + } + + return ConstantArray::get(cast(NewType), NewArrayElements); + } + + // If not a vector or array, return the original initializer + return Init; +} + +static void findAndReplaceVectors(Module &M) { + LLVMContext &Ctx = M.getContext(); + IRBuilder<> Builder(Ctx); + DataScalarizerVisitor Impl; + for (GlobalVariable &G : M.globals()) { + Type *OrigType = G.getValueType(); + // Recursively replace vectors in the type + Type *NewType = replaceVectorWithArray(OrigType, Ctx); + if (OrigType != NewType) { + // Create a new global variable with the updated type + GlobalVariable *NewGlobal = new GlobalVariable( + M, NewType, G.isConstant(), G.getLinkage(), + // This is set via: transformInitializer + nullptr, G.getName() + ".scalarized", &G, G.getThreadLocalMode(), + G.getAddressSpace(), G.isExternallyInitialized()); + + // Copy relevant attributes + NewGlobal->setUnnamedAddr(G.getUnnamedAddr()); + if (G.getAlignment() > 0) { + NewGlobal->setAlignment(Align(G.getAlignment())); + } + + if (G.hasInitializer()) { + Constant *Init = G.getInitializer(); + Constant *NewInit = transformInitializer(Init, OrigType, NewType, Ctx); + NewGlobal->setInitializer(NewInit); + } + + // Note: we want to do G.replaceAllUsesWith(NewGlobal);, but it assumes + // type equality + // So instead we will use the visitor pattern + Impl.GlobalMap[&G] = NewGlobal; + for (User *U : G.users()) { + if (isa(U) && isa(U)) { + ConstantExpr *CE = cast(U); + convertUsersOfConstantsToInstructions(CE, + /*RestrictToFunc=*/nullptr, + /*RemoveDeadConstants=*/false, + /*IncludeSelf=*/true); + } + } + // Uses should have grown + std::vector UsersToProcess; + // Collect all users first + // work around so I can delete + // in a loop body + for (User *U : G.users()) { + UsersToProcess.push_back(U); + } + + // Now process each user + for (User *U : UsersToProcess) { + if (isa(U)) { + Instruction *Inst = cast(U); + Function *F = Inst->getFunction(); + if (F) + Impl.visit(*F); + } + } + } + } + + // Remove the old globals after the iteration + for (auto Pair : Impl.GlobalMap) { + GlobalVariable *OldG = Pair.getFirst(); + OldG->eraseFromParent(); + } +} + +PreservedAnalyses DXILDataScalarization::run(Module &M, + ModuleAnalysisManager &) { + findAndReplaceVectors(M); + return PreservedAnalyses::none(); +} + +bool DXILDataScalarizationLegacy::runOnModule(Module &M) { + findAndReplaceVectors(M); + return true; +} + +void DXILDataScalarizationLegacy::getAnalysisUsage(AnalysisUsage &AU) const {} + +char DXILDataScalarizationLegacy::ID = 0; + +INITIALIZE_PASS_BEGIN(DXILDataScalarizationLegacy, DEBUG_TYPE, + "DXIL Data Scalarization", false, false) +INITIALIZE_PASS_END(DXILDataScalarizationLegacy, DEBUG_TYPE, + "DXIL Data Scalarization", false, false) + +ModulePass *llvm::createDXILDataScalarizationLegacyPass() { + return new DXILDataScalarizationLegacy(); +} \ No newline at end of file diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.h b/llvm/lib/Target/DirectX/DXILDataScalarization.h new file mode 100644 index 0000000000000..d06119397ddb2 --- /dev/null +++ b/llvm/lib/Target/DirectX/DXILDataScalarization.h @@ -0,0 +1,35 @@ +//===- DXILDataScalarization.h - Prepare LLVM Module for DXIL Data +//Legalization----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===------------------------------------------------------------------------------===// +#ifndef LLVM_TARGET_DIRECTX_DXILDATASCALARIZATION_H +#define LLVM_TARGET_DIRECTX_DXILDATASCALARIZATION_H + +#include "DXILResource.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Pass.h" + +namespace llvm { + +/// A pass thattransforms Vectors to Arrays +class DXILDataScalarization : public PassInfoMixin { +public: + PreservedAnalyses run(Module &M, ModuleAnalysisManager &); +}; + +class DXILDataScalarizationLegacy : public ModulePass { + +public: + bool runOnModule(Module &M) override; + DXILDataScalarizationLegacy() : ModulePass(ID) {} + + void getAnalysisUsage(AnalysisUsage &AU) const override; + static char ID; // Pass identification. +}; +} // namespace llvm + +#endif // LLVM_TARGET_DIRECTX_DXILDATASCALARIZATION_H diff --git a/llvm/lib/Target/DirectX/DirectX.h b/llvm/lib/Target/DirectX/DirectX.h index 60fc5094542b3..3221779be2f31 100644 --- a/llvm/lib/Target/DirectX/DirectX.h +++ b/llvm/lib/Target/DirectX/DirectX.h @@ -34,6 +34,12 @@ void initializeDXILIntrinsicExpansionLegacyPass(PassRegistry &); /// Pass to expand intrinsic operations that lack DXIL opCodes ModulePass *createDXILIntrinsicExpansionLegacyPass(); +/// Initializer for DXIL Data Scalarization Pass +void initializeDXILDataScalarizationLegacyPass(PassRegistry &); + +/// Pass to scalarize llvm global data into a DXIL legal form +ModulePass *createDXILDataScalarizationLegacyPass(); + /// Initializer for DXILOpLowering void initializeDXILOpLoweringLegacyPass(PassRegistry &); diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp index 606022a9835f0..f358215ecf373 100644 --- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp +++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp @@ -46,6 +46,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() { RegisterTargetMachine X(getTheDirectXTarget()); auto *PR = PassRegistry::getPassRegistry(); initializeDXILIntrinsicExpansionLegacyPass(*PR); + initializeDXILDataScalarizationLegacyPass(*PR); initializeScalarizerLegacyPassPass(*PR); initializeDXILPrepareModulePass(*PR); initializeEmbedDXILPassPass(*PR); @@ -86,6 +87,7 @@ class DirectXPassConfig : public TargetPassConfig { FunctionPass *createTargetRegisterAllocator(bool) override { return nullptr; } void addCodeGenPrepare() override { addPass(createDXILIntrinsicExpansionLegacyPass()); + addPass(createDXILDataScalarizationLegacyPass()); ScalarizerPassOptions DxilScalarOptions; DxilScalarOptions.ScalarizeLoadStore = true; addPass(createScalarizerPass(DxilScalarOptions)); diff --git a/llvm/test/CodeGen/DirectX/scalar-load.ll b/llvm/test/CodeGen/DirectX/scalar-load.ll new file mode 100644 index 0000000000000..bd99b63883e9f --- /dev/null +++ b/llvm/test/CodeGen/DirectX/scalar-load.ll @@ -0,0 +1,40 @@ +; RUN: opt -S -dxil-data-scalarization -scalarizer -scalarize-load-store -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s +; RUN: llc %s -mtriple=dxil-pc-shadermodel6.3-library --filetype=asm -o - | FileCheck %s +@"arrayofVecData" = local_unnamed_addr addrspace(3) global [2 x <3 x float>] zeroinitializer, align 16 +@"vecData" = external addrspace(3) global <4 x i32>, align 4 +@staticArrayOfVecData = internal global [3 x <4 x i32>] zeroinitializer, align 4 + +; CHECK: @arrayofVecData.scalarized = local_unnamed_addr addrspace(3) global [2 x [3 x float]] zeroinitializer, align 16 +; CHECK: @vecData.scalarized = external addrspace(3) global [4 x i32], align 4 +; CHECK: @staticArrayOfVecData.scalarized = internal global [3 x [4 x i32]] zeroinitializer, align 4 +; CHECK-NOT: @arrayofVecData +; CHECK-NOT: @vecData +; CHECK-NOT: @staticArrayOfVecData + +; CHECK-LABEL: load_array_vec_test +define <4 x i32> @load_array_vec_test() { + ; CHECK-COUNT-8: load i32, ptr addrspace(3) {{(.*@arrayofVecData.scalarized.*|%.*)}}, align 4 + ; CHECK-NOT: load i32, ptr addrspace(3) {{.*}}, align 4 + %1 = load <4 x i32>, <4 x i32> addrspace(3)* getelementptr inbounds ([2 x <4 x i32>], [2 x <4 x i32>] addrspace(3)* @"arrayofVecData", i32 0, i32 0), align 4 + %2 = load <4 x i32>, <4 x i32> addrspace(3)* getelementptr inbounds ([2 x <4 x i32>], [2 x <4 x i32>] addrspace(3)* @"arrayofVecData", i32 0, i32 1), align 4 + %3 = add <4 x i32> %1, %2 + ret <4 x i32> %3 +} + +; CHECK-LABEL: load_vec_test +define <4 x i32> @load_vec_test() { + ; CHECK-COUNT-4: load i32, ptr addrspace(3) {{(@vecData.scalarized|getelementptr \(i32, ptr addrspace\(3\) @vecData.scalarized, i32 .*\)|%.*)}}, align {{.*}} + ; CHECK-NOT: load i32, ptr addrspace(3) {{.*}}, align 4 + %1 = load <4 x i32>, <4 x i32> addrspace(3)* @"vecData", align 4 + ret <4 x i32> %1 +} + +; CHECK-LABEL: load_static_array_of_vec_test +define <4 x i32> @load_static_array_of_vec_test(i32 %index) { + ; CHECK: getelementptr [3 x [4 x i32]], ptr @staticArrayOfVecData.scalarized, i32 0, i32 %index + ; CHECK-COUNT-4: load i32, ptr {{.*}}, align 4 + ; CHECK-NOT: load i32, ptr {{.*}}, align 4 + %3 = getelementptr inbounds [3 x <4 x i32>], [3 x <4 x i32>]* @staticArrayOfVecData, i32 0, i32 %index + %4 = load <4 x i32>, <4 x i32>* %3, align 4 + ret <4 x i32> %4 +} diff --git a/llvm/test/CodeGen/DirectX/scalar-store.ll b/llvm/test/CodeGen/DirectX/scalar-store.ll index b970a2842e5a8..767d2e47c3e8e 100644 --- a/llvm/test/CodeGen/DirectX/scalar-store.ll +++ b/llvm/test/CodeGen/DirectX/scalar-store.ll @@ -1,17 +1,27 @@ -; RUN: opt -S -scalarizer -scalarize-load-store -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s +; RUN: opt -S -dxil-data-scalarization -scalarizer -scalarize-load-store -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s ; RUN: llc %s -mtriple=dxil-pc-shadermodel6.3-library --filetype=asm -o - | FileCheck %s -@"sharedData" = local_unnamed_addr addrspace(3) global [2 x <3 x float>] zeroinitializer, align 16 -; CHECK-LABEL: store_test -define void @store_test () local_unnamed_addr { - ; CHECK: store float 1.000000e+00, ptr addrspace(3) {{.*}}, align {{.*}} - ; CHECK: store float 2.000000e+00, ptr addrspace(3) {{.*}}, align {{.*}} - ; CHECK: store float 3.000000e+00, ptr addrspace(3) {{.*}}, align {{.*}} - ; CHECK: store float 2.000000e+00, ptr addrspace(3) {{.*}}, align {{.*}} - ; CHECK: store float 4.000000e+00, ptr addrspace(3) {{.*}}, align {{.*}} - ; CHECK: store float 6.000000e+00, ptr addrspace(3) {{.*}}, align {{.*}} +@"arrayofVecData" = local_unnamed_addr addrspace(3) global [2 x <3 x float>] zeroinitializer, align 16 +@"vecData" = external addrspace(3) global <4 x i32>, align 4 - store <3 x float> , ptr addrspace(3) @"sharedData", align 16 - store <3 x float> , ptr addrspace(3) getelementptr inbounds (i8, ptr addrspace(3) @"sharedData", i32 16), align 16 +; CHECK: @arrayofVecData.scalarized = local_unnamed_addr addrspace(3) global [2 x [3 x float]] zeroinitializer, align 16 +; CHECK: @vecData.scalarized = external addrspace(3) global [4 x i32], align 4 +; CHECK-NOT: @arrayofVecData +; CHECK-NOT: @vecData + +; CHECK-LABEL: store_array_vec_test +define void @store_array_vec_test () local_unnamed_addr { + ; CHECK-COUNT-6: store float {{1|2|3|4|6}}.000000e+00, ptr addrspace(3) {{(.*@arrayofVecData.scalarized.*|%.*)}}, align {{4|8|16}} + ; CHECK-NOT: store float {{1|2|3|4|6}}.000000e+00, ptr addrspace(3) {{(.*@arrayofVecData.scalarized.*|%.*)}}, align {{4|8|16}} + store <3 x float> , ptr addrspace(3) @"arrayofVecData", align 16 + store <3 x float> , ptr addrspace(3) getelementptr inbounds (i8, ptr addrspace(3) @"arrayofVecData", i32 16), align 16 ret void } + +; CHECK-LABEL: store_vec_test +define void @store_vec_test(<4 x i32> %inputVec) { + ; CHECK-COUNT-4: store i32 %inputVec.{{.*}}, ptr addrspace(3) {{(@vecData.scalarized|getelementptr \(i32, ptr addrspace\(3\) @vecData.scalarized, i32 .*\)|%.*)}}, align 4 + ; CHECK-NOT: store i32 %inputVec.{{.*}}, ptr addrspace(3) + store <4 x i32> %inputVec, <4 x i32> addrspace(3)* @"vecData", align 4 + ret void +} \ No newline at end of file From 1e89daf5684508de52ba66ff54c8ae2d8342cd5a Mon Sep 17 00:00:00 2001 From: Farzon Lotfi Date: Tue, 24 Sep 2024 02:29:05 -0400 Subject: [PATCH 2/5] cleanup comments and dead code --- .../Target/DirectX/DXILDataScalarization.cpp | 34 +++++-------------- .../Target/DirectX/DXILDataScalarization.h | 4 +-- 2 files changed, 11 insertions(+), 27 deletions(-) diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp index e689c0b06fccf..b586e758c473f 100644 --- a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp +++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp @@ -1,5 +1,5 @@ //===- DXILDataScalarization.cpp - Prepare LLVM Module for DXIL Data -//Legalization----===// +// Legalization----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -77,7 +77,6 @@ bool DataScalarizerVisitor::visit(Function &F) { } bool DataScalarizerVisitor::finish() { - // TODO this should do cleanup RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs); return true; } @@ -104,30 +103,14 @@ bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) { } bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) { - bool ReplaceStore = false; for (unsigned I = 0; I < SI.getNumOperands(); ++I) { Value *CurrOpperand = SI.getOperand(I); GlobalVariable *NewGlobal = getNewGlobalIfExists(CurrOpperand); if (NewGlobal) { SI.setOperand(I, NewGlobal); - /*Value *StoredValue = SI.getValueOperand(); - Type *StoredType = StoredValue->getType(); - if (VectorType *VecTy = dyn_cast(StoredType)) { - unsigned NumElements = cast(VecTy)->getNumElements(); - ArrayType *ArrayTy = ArrayType::get(VecTy->getElementType(), - NumElements); std::vector ConstElements; if (ConstantVector - *ConstVec = dyn_cast(StoredValue)) { for(uint I = 0; I < - NumElements; I++) { ConstElements.push_back(ConstVec->getOperand(I)); - } - } - Value *ArrayValue = ConstantArray::get(ArrayTy,ConstElements); - IRBuilder<> Builder(&SI); - Builder.CreateStore(ArrayValue, SI.getPointerOperand()); - replaceStore = true; - }*/ } } - return ReplaceStore; + return false; } bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) { @@ -135,18 +118,15 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) { Value *CurrOpperand = GEPI.getOperand(I); GlobalVariable *NewGlobal = getNewGlobalIfExists(CurrOpperand); if (NewGlobal) { - // Prepare to create a new GEP for the new global - IRBuilder<> Builder(&GEPI); // Create an IRBuilder at the position of GEPI + IRBuilder<> Builder(&GEPI); SmallVector Indices; for (auto &Index : GEPI.indices()) Indices.push_back(Index); - // Create a new GEP for the new global variable Value *NewGEP = Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices); - // Replace the old GEP with the new one GEPI.replaceAllUsesWith(NewGEP); PotentiallyDeadInstrs.emplace_back(&GEPI); } @@ -154,6 +134,7 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) { return true; } +// Recursively Creates and Array like version of the given vector like type. static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) { if (auto *VecTy = dyn_cast(T)) return ArrayType::get(VecTy->getElementType(), @@ -197,7 +178,7 @@ Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType, // Handle array of vectors transformation if (auto *ArrayTy = dyn_cast(OrigType)) { - // Recursively transform array elements + auto *ArrayInit = dyn_cast(Init); if (!ArrayInit) { llvm_unreachable("Expected a ConstantArray for array initializer!"); @@ -205,6 +186,7 @@ Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType, SmallVector NewArrayElements; for (unsigned I = 0; I < ArrayTy->getNumElements(); ++I) { + // Recursively transform array elements Constant *NewElemInit = transformInitializer( ArrayInit->getOperand(I), ArrayTy->getElementType(), cast(NewType)->getElementType(), Ctx); @@ -224,7 +206,7 @@ static void findAndReplaceVectors(Module &M) { DataScalarizerVisitor Impl; for (GlobalVariable &G : M.globals()) { Type *OrigType = G.getValueType(); - // Recursively replace vectors in the type + Type *NewType = replaceVectorWithArray(OrigType, Ctx); if (OrigType != NewType) { // Create a new global variable with the updated type @@ -251,6 +233,8 @@ static void findAndReplaceVectors(Module &M) { // So instead we will use the visitor pattern Impl.GlobalMap[&G] = NewGlobal; for (User *U : G.users()) { + // Note: The GEPS are stored as constExprs + // This step flattens them out to instructions if (isa(U) && isa(U)) { ConstantExpr *CE = cast(U); convertUsersOfConstantsToInstructions(CE, diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.h b/llvm/lib/Target/DirectX/DXILDataScalarization.h index d06119397ddb2..b6c59f7b33fd4 100644 --- a/llvm/lib/Target/DirectX/DXILDataScalarization.h +++ b/llvm/lib/Target/DirectX/DXILDataScalarization.h @@ -1,5 +1,5 @@ //===- DXILDataScalarization.h - Prepare LLVM Module for DXIL Data -//Legalization----===// +// Legalization----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -15,7 +15,7 @@ namespace llvm { -/// A pass thattransforms Vectors to Arrays +/// A pass that transforms Vectors to Arrays class DXILDataScalarization : public PassInfoMixin { public: PreservedAnalyses run(Module &M, ModuleAnalysisManager &); From 2e21e3d2ea8df7bf2f5cbfa817268acb1e6f8152 Mon Sep 17 00:00:00 2001 From: Farzon Lotfi Date: Wed, 25 Sep 2024 03:43:26 -0400 Subject: [PATCH 3/5] check for ConstantDataVector --- llvm/lib/Target/DirectX/DXILDataScalarization.cpp | 15 ++++++++------- llvm/test/CodeGen/DirectX/scalar-load.ll | 7 +++++-- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp index b586e758c473f..0557b21930b7f 100644 --- a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp +++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp @@ -163,14 +163,15 @@ Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType, // Handle vector to array transformation if (isa(OrigType) && isa(NewType)) { // Convert vector initializer to array initializer - auto *VecInit = dyn_cast(Init); - if (!VecInit) { - llvm_unreachable("Expected a ConstantVector for vector initializer!"); - } - SmallVector ArrayElements; - for (unsigned I = 0; I < VecInit->getNumOperands(); ++I) { - ArrayElements.push_back(VecInit->getOperand(I)); + if( ConstantVector *ConstVecInit = dyn_cast(Init)) { + for (unsigned I = 0; I < ConstVecInit->getNumOperands(); ++I) + ArrayElements.push_back(ConstVecInit->getOperand(I)); + } else if (ConstantDataVector *ConstDataVecInit = llvm::dyn_cast(Init)) { + for (unsigned I = 0; I < ConstDataVecInit->getNumElements(); ++I) + ArrayElements.push_back(ConstDataVecInit->getElementAsConstant(I)); + } else { + llvm_unreachable("Expected a ConstantVector or ConstantDataVector for vector initializer!"); } return ConstantArray::get(cast(NewType), ArrayElements); diff --git a/llvm/test/CodeGen/DirectX/scalar-load.ll b/llvm/test/CodeGen/DirectX/scalar-load.ll index bd99b63883e9f..fa1f1290867b2 100644 --- a/llvm/test/CodeGen/DirectX/scalar-load.ll +++ b/llvm/test/CodeGen/DirectX/scalar-load.ll @@ -2,11 +2,14 @@ ; RUN: llc %s -mtriple=dxil-pc-shadermodel6.3-library --filetype=asm -o - | FileCheck %s @"arrayofVecData" = local_unnamed_addr addrspace(3) global [2 x <3 x float>] zeroinitializer, align 16 @"vecData" = external addrspace(3) global <4 x i32>, align 4 -@staticArrayOfVecData = internal global [3 x <4 x i32>] zeroinitializer, align 4 +;@staticArrayOfVecData = internal global [3 x <4 x i32>] zeroinitializer, align 4 +@staticArrayOfVecData = internal global [3 x <4 x i32>] [<4 x i32> , <4 x i32> , <4 x i32> ], align 4 + ; CHECK: @arrayofVecData.scalarized = local_unnamed_addr addrspace(3) global [2 x [3 x float]] zeroinitializer, align 16 ; CHECK: @vecData.scalarized = external addrspace(3) global [4 x i32], align 4 -; CHECK: @staticArrayOfVecData.scalarized = internal global [3 x [4 x i32]] zeroinitializer, align 4 +; CHECK: @staticArrayOfVecData.scalarized = internal global [3 x [4 x i32]] {{\[}}[4 x i32] [i32 1, i32 2, i32 3, i32 4], [4 x i32] [i32 5, i32 6, i32 7, i32 8], [4 x i32] [i32 9, i32 10, i32 11, i32 12]], align 4 + ; CHECK-NOT: @arrayofVecData ; CHECK-NOT: @vecData ; CHECK-NOT: @staticArrayOfVecData From 6f6e42c881ba1be6387f953300b6488f2fa1f029 Mon Sep 17 00:00:00 2001 From: Farzon Lotfi Date: Wed, 25 Sep 2024 15:08:27 -0400 Subject: [PATCH 4/5] fix iterator stability issue --- .../Target/DirectX/DXILDataScalarization.cpp | 53 +++++++------------ llvm/test/CodeGen/DirectX/llc-pipeline.ll | 1 + llvm/test/CodeGen/DirectX/scalar-load.ll | 5 +- llvm/test/CodeGen/DirectX/scalar-store.ll | 2 +- 4 files changed, 25 insertions(+), 36 deletions(-) diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp index 0557b21930b7f..9fa39a5d71d86 100644 --- a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp +++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp @@ -1,15 +1,15 @@ -//===- DXILDataScalarization.cpp - Prepare LLVM Module for DXIL Data -// Legalization----===// +//===- DXILDataScalarization.cpp - Perform DXIL Data Legalization----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -//===--------------------------------------------------------------------------------===// +//===----------------------------------------------------------------===// #include "DXILDataScalarization.h" #include "DirectX.h" #include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstVisitor.h" @@ -20,7 +20,6 @@ #include "llvm/IR/Type.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/Local.h" -#include #define DEBUG_TYPE "dxil-data-scalarization" #define Max_VEC_SIZE 4 @@ -164,14 +163,16 @@ Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType, if (isa(OrigType) && isa(NewType)) { // Convert vector initializer to array initializer SmallVector ArrayElements; - if( ConstantVector *ConstVecInit = dyn_cast(Init)) { - for (unsigned I = 0; I < ConstVecInit->getNumOperands(); ++I) - ArrayElements.push_back(ConstVecInit->getOperand(I)); - } else if (ConstantDataVector *ConstDataVecInit = llvm::dyn_cast(Init)) { - for (unsigned I = 0; I < ConstDataVecInit->getNumElements(); ++I) - ArrayElements.push_back(ConstDataVecInit->getElementAsConstant(I)); - } else { - llvm_unreachable("Expected a ConstantVector or ConstantDataVector for vector initializer!"); + if (ConstantVector *ConstVecInit = dyn_cast(Init)) { + for (unsigned I = 0; I < ConstVecInit->getNumOperands(); ++I) + ArrayElements.push_back(ConstVecInit->getOperand(I)); + } else if (ConstantDataVector *ConstDataVecInit = + llvm::dyn_cast(Init)) { + for (unsigned I = 0; I < ConstDataVecInit->getNumElements(); ++I) + ArrayElements.push_back(ConstDataVecInit->getElementAsConstant(I)); + } else { + llvm_unreachable("Expected a ConstantVector or ConstantDataVector for " + "vector initializer!"); } return ConstantArray::get(cast(NewType), ArrayElements); @@ -213,9 +214,10 @@ static void findAndReplaceVectors(Module &M) { // Create a new global variable with the updated type GlobalVariable *NewGlobal = new GlobalVariable( M, NewType, G.isConstant(), G.getLinkage(), - // This is set via: transformInitializer - nullptr, G.getName() + ".scalarized", &G, G.getThreadLocalMode(), - G.getAddressSpace(), G.isExternallyInitialized()); + // Initializer is set via transformInitializer + /*Initializer=*/nullptr, G.getName() + ".scalarized", &G, + G.getThreadLocalMode(), G.getAddressSpace(), + G.isExternallyInitialized()); // Copy relevant attributes NewGlobal->setUnnamedAddr(G.getUnnamedAddr()); @@ -230,12 +232,9 @@ static void findAndReplaceVectors(Module &M) { } // Note: we want to do G.replaceAllUsesWith(NewGlobal);, but it assumes - // type equality - // So instead we will use the visitor pattern + // type equality. Instead we will use the visitor pattern. Impl.GlobalMap[&G] = NewGlobal; - for (User *U : G.users()) { - // Note: The GEPS are stored as constExprs - // This step flattens them out to instructions + for (User *U : make_early_inc_range(G.users())) { if (isa(U) && isa(U)) { ConstantExpr *CE = cast(U); convertUsersOfConstantsToInstructions(CE, @@ -243,18 +242,6 @@ static void findAndReplaceVectors(Module &M) { /*RemoveDeadConstants=*/false, /*IncludeSelf=*/true); } - } - // Uses should have grown - std::vector UsersToProcess; - // Collect all users first - // work around so I can delete - // in a loop body - for (User *U : G.users()) { - UsersToProcess.push_back(U); - } - - // Now process each user - for (User *U : UsersToProcess) { if (isa(U)) { Instruction *Inst = cast(U); Function *F = Inst->getFunction(); @@ -294,4 +281,4 @@ INITIALIZE_PASS_END(DXILDataScalarizationLegacy, DEBUG_TYPE, ModulePass *llvm::createDXILDataScalarizationLegacyPass() { return new DXILDataScalarizationLegacy(); -} \ No newline at end of file +} diff --git a/llvm/test/CodeGen/DirectX/llc-pipeline.ll b/llvm/test/CodeGen/DirectX/llc-pipeline.ll index 46326d6917587..102748508b4ad 100644 --- a/llvm/test/CodeGen/DirectX/llc-pipeline.ll +++ b/llvm/test/CodeGen/DirectX/llc-pipeline.ll @@ -8,6 +8,7 @@ ; CHECK-NEXT: Target Transform Information ; CHECK-NEXT: ModulePass Manager ; CHECK-NEXT: DXIL Intrinsic Expansion +; CHECK-NEXT: DXIL Data Scalarization ; CHECK-NEXT: FunctionPass Manager ; CHECK-NEXT: Dominator Tree Construction ; CHECK-NEXT: Scalarize vector operations diff --git a/llvm/test/CodeGen/DirectX/scalar-load.ll b/llvm/test/CodeGen/DirectX/scalar-load.ll index fa1f1290867b2..1f4834ebfd04f 100644 --- a/llvm/test/CodeGen/DirectX/scalar-load.ll +++ b/llvm/test/CodeGen/DirectX/scalar-load.ll @@ -2,17 +2,18 @@ ; RUN: llc %s -mtriple=dxil-pc-shadermodel6.3-library --filetype=asm -o - | FileCheck %s @"arrayofVecData" = local_unnamed_addr addrspace(3) global [2 x <3 x float>] zeroinitializer, align 16 @"vecData" = external addrspace(3) global <4 x i32>, align 4 -;@staticArrayOfVecData = internal global [3 x <4 x i32>] zeroinitializer, align 4 @staticArrayOfVecData = internal global [3 x <4 x i32>] [<4 x i32> , <4 x i32> , <4 x i32> ], align 4 - +@staticArray = internal global [4 x i32] [i32 1, i32 2, i32 3, i32 4], align 4 ; CHECK: @arrayofVecData.scalarized = local_unnamed_addr addrspace(3) global [2 x [3 x float]] zeroinitializer, align 16 ; CHECK: @vecData.scalarized = external addrspace(3) global [4 x i32], align 4 ; CHECK: @staticArrayOfVecData.scalarized = internal global [3 x [4 x i32]] {{\[}}[4 x i32] [i32 1, i32 2, i32 3, i32 4], [4 x i32] [i32 5, i32 6, i32 7, i32 8], [4 x i32] [i32 9, i32 10, i32 11, i32 12]], align 4 +; Check @staticArray ; CHECK-NOT: @arrayofVecData ; CHECK-NOT: @vecData ; CHECK-NOT: @staticArrayOfVecData +; CHECK-NOT: @staticArray.scalarized ; CHECK-LABEL: load_array_vec_test define <4 x i32> @load_array_vec_test() { diff --git a/llvm/test/CodeGen/DirectX/scalar-store.ll b/llvm/test/CodeGen/DirectX/scalar-store.ll index 767d2e47c3e8e..aac4711c3f97f 100644 --- a/llvm/test/CodeGen/DirectX/scalar-store.ll +++ b/llvm/test/CodeGen/DirectX/scalar-store.ll @@ -24,4 +24,4 @@ define void @store_vec_test(<4 x i32> %inputVec) { ; CHECK-NOT: store i32 %inputVec.{{.*}}, ptr addrspace(3) store <4 x i32> %inputVec, <4 x i32> addrspace(3)* @"vecData", align 4 ret void -} \ No newline at end of file +} From 72baf7242e14fd83886021916cef5702a6b7a33d Mon Sep 17 00:00:00 2001 From: Farzon Lotfi Date: Wed, 25 Sep 2024 21:21:22 -0400 Subject: [PATCH 5/5] address pr comments --- .../Target/DirectX/DXILDataScalarization.cpp | 106 ++++++++++-------- .../Target/DirectX/DXILDataScalarization.h | 16 +-- llvm/test/CodeGen/DirectX/scalar-data.ll | 12 ++ llvm/test/CodeGen/DirectX/scalar-load.ll | 20 +++- llvm/test/CodeGen/DirectX/scalar-store.ll | 2 + 5 files changed, 95 insertions(+), 61 deletions(-) create mode 100644 llvm/test/CodeGen/DirectX/scalar-data.ll diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp index 9fa39a5d71d86..0e6cf59e25750 100644 --- a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp +++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp @@ -1,15 +1,16 @@ -//===- DXILDataScalarization.cpp - Perform DXIL Data Legalization----===// +//===- DXILDataScalarization.cpp - Perform DXIL Data Legalization ---------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -//===----------------------------------------------------------------===// +//===---------------------------------------------------------------------===// #include "DXILDataScalarization.h" #include "DirectX.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Analysis/DXILResource.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstVisitor.h" @@ -22,11 +23,21 @@ #include "llvm/Transforms/Utils/Local.h" #define DEBUG_TYPE "dxil-data-scalarization" -#define Max_VEC_SIZE 4 +static const int MaxVecSize = 4; using namespace llvm; -static void findAndReplaceVectors(Module &M); +class DXILDataScalarizationLegacy : public ModulePass { + +public: + bool runOnModule(Module &M) override; + DXILDataScalarizationLegacy() : ModulePass(ID) {} + + void getAnalysisUsage(AnalysisUsage &AU) const override; + static char ID; // Pass identification. +}; + +static bool findAndReplaceVectors(Module &M); class DataScalarizerVisitor : public InstVisitor { public: @@ -51,10 +62,10 @@ class DataScalarizerVisitor : public InstVisitor { bool visitStoreInst(StoreInst &SI); bool visitCallInst(CallInst &ICI) { return false; } bool visitFreezeInst(FreezeInst &FI) { return false; } - friend void findAndReplaceVectors(llvm::Module &M); + friend bool findAndReplaceVectors(llvm::Module &M); private: - GlobalVariable *getNewGlobalIfExists(Value *CurrOperand); + GlobalVariable *lookupReplacementGlobal(Value *CurrOperand); DenseMap GlobalMap; SmallVector PotentiallyDeadInstrs; bool finish(); @@ -81,7 +92,7 @@ bool DataScalarizerVisitor::finish() { } GlobalVariable * -DataScalarizerVisitor::getNewGlobalIfExists(Value *CurrOperand) { +DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) { if (GlobalVariable *OldGlobal = dyn_cast(CurrOperand)) { auto It = GlobalMap.find(OldGlobal); if (It != GlobalMap.end()) { @@ -92,20 +103,20 @@ DataScalarizerVisitor::getNewGlobalIfExists(Value *CurrOperand) { } bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) { - for (unsigned I = 0; I < LI.getNumOperands(); ++I) { + unsigned NumOperands = LI.getNumOperands(); + for (unsigned I = 0; I < NumOperands; ++I) { Value *CurrOpperand = LI.getOperand(I); - GlobalVariable *NewGlobal = getNewGlobalIfExists(CurrOpperand); - if (NewGlobal) + if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand)) LI.setOperand(I, NewGlobal); } return false; } bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) { - for (unsigned I = 0; I < SI.getNumOperands(); ++I) { + unsigned NumOperands = SI.getNumOperands(); + for (unsigned I = 0; I < NumOperands; ++I) { Value *CurrOpperand = SI.getOperand(I); - GlobalVariable *NewGlobal = getNewGlobalIfExists(CurrOpperand); - if (NewGlobal) { + if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand)) { SI.setOperand(I, NewGlobal); } } @@ -113,22 +124,23 @@ bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) { } bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) { - for (unsigned I = 0; I < GEPI.getNumOperands(); ++I) { + unsigned NumOperands = GEPI.getNumOperands(); + for (unsigned I = 0; I < NumOperands; ++I) { Value *CurrOpperand = GEPI.getOperand(I); - GlobalVariable *NewGlobal = getNewGlobalIfExists(CurrOpperand); - if (NewGlobal) { - IRBuilder<> Builder(&GEPI); + GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand); + if (!NewGlobal) + continue; + IRBuilder<> Builder(&GEPI); - SmallVector Indices; - for (auto &Index : GEPI.indices()) - Indices.push_back(Index); + SmallVector Indices; + for (auto &Index : GEPI.indices()) + Indices.push_back(Index); - Value *NewGEP = - Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices); + Value *NewGEP = + Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices); - GEPI.replaceAllUsesWith(NewGEP); - PotentiallyDeadInstrs.emplace_back(&GEPI); - } + GEPI.replaceAllUsesWith(NewGEP); + PotentiallyDeadInstrs.emplace_back(&GEPI); } return true; } @@ -137,7 +149,7 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) { static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) { if (auto *VecTy = dyn_cast(T)) return ArrayType::get(VecTy->getElementType(), - cast(VecTy)->getNumElements()); + dyn_cast(VecTy)->getNumElements()); if (auto *ArrayTy = dyn_cast(T)) { Type *NewElementType = replaceVectorWithArray(ArrayTy->getElementType(), Ctx); @@ -162,7 +174,7 @@ Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType, // Handle vector to array transformation if (isa(OrigType) && isa(NewType)) { // Convert vector initializer to array initializer - SmallVector ArrayElements; + SmallVector ArrayElements; if (ConstantVector *ConstVecInit = dyn_cast(Init)) { for (unsigned I = 0; I < ConstVecInit->getNumOperands(); ++I) ArrayElements.push_back(ConstVecInit->getOperand(I)); @@ -171,8 +183,8 @@ Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType, for (unsigned I = 0; I < ConstDataVecInit->getNumElements(); ++I) ArrayElements.push_back(ConstDataVecInit->getElementAsConstant(I)); } else { - llvm_unreachable("Expected a ConstantVector or ConstantDataVector for " - "vector initializer!"); + assert(false && "Expected a ConstantVector or ConstantDataVector for " + "vector initializer!"); } return ConstantArray::get(cast(NewType), ArrayElements); @@ -180,13 +192,10 @@ Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType, // Handle array of vectors transformation if (auto *ArrayTy = dyn_cast(OrigType)) { - auto *ArrayInit = dyn_cast(Init); - if (!ArrayInit) { - llvm_unreachable("Expected a ConstantArray for array initializer!"); - } + assert(ArrayInit && "Expected a ConstantArray for array initializer!"); - SmallVector NewArrayElements; + SmallVector NewArrayElements; for (unsigned I = 0; I < ArrayTy->getNumElements(); ++I) { // Recursively transform array elements Constant *NewElemInit = transformInitializer( @@ -202,7 +211,8 @@ Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType, return Init; } -static void findAndReplaceVectors(Module &M) { +static bool findAndReplaceVectors(Module &M) { + bool MadeChange = false; LLVMContext &Ctx = M.getContext(); IRBuilder<> Builder(Ctx); DataScalarizerVisitor Impl; @@ -212,9 +222,9 @@ static void findAndReplaceVectors(Module &M) { Type *NewType = replaceVectorWithArray(OrigType, Ctx); if (OrigType != NewType) { // Create a new global variable with the updated type + // Note: Initializer is set via transformInitializer GlobalVariable *NewGlobal = new GlobalVariable( M, NewType, G.isConstant(), G.getLinkage(), - // Initializer is set via transformInitializer /*Initializer=*/nullptr, G.getName() + ".scalarized", &G, G.getThreadLocalMode(), G.getAddressSpace(), G.isExternallyInitialized()); @@ -222,7 +232,7 @@ static void findAndReplaceVectors(Module &M) { // Copy relevant attributes NewGlobal->setUnnamedAddr(G.getUnnamedAddr()); if (G.getAlignment() > 0) { - NewGlobal->setAlignment(Align(G.getAlignment())); + NewGlobal->setAlignment(G.getAlign()); } if (G.hasInitializer()) { @@ -253,24 +263,30 @@ static void findAndReplaceVectors(Module &M) { } // Remove the old globals after the iteration - for (auto Pair : Impl.GlobalMap) { - GlobalVariable *OldG = Pair.getFirst(); - OldG->eraseFromParent(); + for (auto &[Old, New] : Impl.GlobalMap) { + Old->eraseFromParent(); + MadeChange = true; } + return MadeChange; } PreservedAnalyses DXILDataScalarization::run(Module &M, ModuleAnalysisManager &) { - findAndReplaceVectors(M); - return PreservedAnalyses::none(); + bool MadeChanges = findAndReplaceVectors(M); + if (!MadeChanges) + return PreservedAnalyses::all(); + PreservedAnalyses PA; + PA.preserve(); + return PA; } bool DXILDataScalarizationLegacy::runOnModule(Module &M) { - findAndReplaceVectors(M); - return true; + return findAndReplaceVectors(M); } -void DXILDataScalarizationLegacy::getAnalysisUsage(AnalysisUsage &AU) const {} +void DXILDataScalarizationLegacy::getAnalysisUsage(AnalysisUsage &AU) const { + AU.addPreserved(); +} char DXILDataScalarizationLegacy::ID = 0; diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.h b/llvm/lib/Target/DirectX/DXILDataScalarization.h index b6c59f7b33fd4..560e061db96d0 100644 --- a/llvm/lib/Target/DirectX/DXILDataScalarization.h +++ b/llvm/lib/Target/DirectX/DXILDataScalarization.h @@ -1,11 +1,11 @@ -//===- DXILDataScalarization.h - Prepare LLVM Module for DXIL Data -// Legalization----===// +//===- DXILDataScalarization.h - Perform DXIL Data Legalization -*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -//===------------------------------------------------------------------------------===// +//===---------------------------------------------------------------------===// + #ifndef LLVM_TARGET_DIRECTX_DXILDATASCALARIZATION_H #define LLVM_TARGET_DIRECTX_DXILDATASCALARIZATION_H @@ -20,16 +20,6 @@ class DXILDataScalarization : public PassInfoMixin { public: PreservedAnalyses run(Module &M, ModuleAnalysisManager &); }; - -class DXILDataScalarizationLegacy : public ModulePass { - -public: - bool runOnModule(Module &M) override; - DXILDataScalarizationLegacy() : ModulePass(ID) {} - - void getAnalysisUsage(AnalysisUsage &AU) const override; - static char ID; // Pass identification. -}; } // namespace llvm #endif // LLVM_TARGET_DIRECTX_DXILDATASCALARIZATION_H diff --git a/llvm/test/CodeGen/DirectX/scalar-data.ll b/llvm/test/CodeGen/DirectX/scalar-data.ll new file mode 100644 index 0000000000000..4438604a3a879 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/scalar-data.ll @@ -0,0 +1,12 @@ +; RUN: opt -S -dxil-data-scalarization -scalarizer -scalarize-load-store -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s +; RUN: llc %s -mtriple=dxil-pc-shadermodel6.3-library --filetype=asm -o - | FileCheck %s + +; Make sure we don't touch arrays without vectors and that can recurse multiple-dimension arrays of vectors + +@staticArray = internal global [4 x i32] [i32 1, i32 2, i32 3, i32 4], align 4 +@"groushared3dArrayofVectors" = local_unnamed_addr addrspace(3) global [3 x [3 x [3 x <4 x i32>]]] zeroinitializer, align 16 + +; CHECK @staticArray +; CHECK-NOT: @staticArray.scalarized +; CHECK: @groushared3dArrayofVectors.scalarized = local_unnamed_addr addrspace(3) global [3 x [3 x [3 x [4 x i32]]]] zeroinitializer, align 16 +; CHECK-NOT: @groushared3dArrayofVectors diff --git a/llvm/test/CodeGen/DirectX/scalar-load.ll b/llvm/test/CodeGen/DirectX/scalar-load.ll index 1f4834ebfd04f..11678f48a5e01 100644 --- a/llvm/test/CodeGen/DirectX/scalar-load.ll +++ b/llvm/test/CodeGen/DirectX/scalar-load.ll @@ -1,19 +1,23 @@ ; RUN: opt -S -dxil-data-scalarization -scalarizer -scalarize-load-store -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s ; RUN: llc %s -mtriple=dxil-pc-shadermodel6.3-library --filetype=asm -o - | FileCheck %s + +; Make sure we can load groupshared, static vectors and arrays of vectors + @"arrayofVecData" = local_unnamed_addr addrspace(3) global [2 x <3 x float>] zeroinitializer, align 16 @"vecData" = external addrspace(3) global <4 x i32>, align 4 @staticArrayOfVecData = internal global [3 x <4 x i32>] [<4 x i32> , <4 x i32> , <4 x i32> ], align 4 -@staticArray = internal global [4 x i32] [i32 1, i32 2, i32 3, i32 4], align 4 +@"groushared2dArrayofVectors" = local_unnamed_addr addrspace(3) global [3 x [ 3 x <4 x i32>]] zeroinitializer, align 16 ; CHECK: @arrayofVecData.scalarized = local_unnamed_addr addrspace(3) global [2 x [3 x float]] zeroinitializer, align 16 ; CHECK: @vecData.scalarized = external addrspace(3) global [4 x i32], align 4 ; CHECK: @staticArrayOfVecData.scalarized = internal global [3 x [4 x i32]] {{\[}}[4 x i32] [i32 1, i32 2, i32 3, i32 4], [4 x i32] [i32 5, i32 6, i32 7, i32 8], [4 x i32] [i32 9, i32 10, i32 11, i32 12]], align 4 -; Check @staticArray +; CHECK: @groushared2dArrayofVectors.scalarized = local_unnamed_addr addrspace(3) global [3 x [3 x [4 x i32]]] zeroinitializer, align 16 ; CHECK-NOT: @arrayofVecData ; CHECK-NOT: @vecData ; CHECK-NOT: @staticArrayOfVecData -; CHECK-NOT: @staticArray.scalarized +; CHECK-NOT: @groushared2dArrayofVectors + ; CHECK-LABEL: load_array_vec_test define <4 x i32> @load_array_vec_test() { @@ -42,3 +46,13 @@ define <4 x i32> @load_static_array_of_vec_test(i32 %index) { %4 = load <4 x i32>, <4 x i32>* %3, align 4 ret <4 x i32> %4 } + +; CHECK-LABEL: multid_load_test +define <4 x i32> @multid_load_test() { + ; CHECK-COUNT-8: load i32, ptr addrspace(3) {{(.*@groushared2dArrayofVectors.scalarized.*|%.*)}}, align 4 + ; CHECK-NOT: load i32, ptr addrspace(3) {{.*}}, align 4 + %1 = load <4 x i32>, <4 x i32> addrspace(3)* getelementptr inbounds ([3 x [3 x <4 x i32>]], [3 x [3 x <4 x i32>]] addrspace(3)* @"groushared2dArrayofVectors", i32 0, i32 0, i32 0), align 4 + %2 = load <4 x i32>, <4 x i32> addrspace(3)* getelementptr inbounds ([3 x [3 x <4 x i32>]], [3 x [3 x <4 x i32>]] addrspace(3)* @"groushared2dArrayofVectors", i32 0, i32 1, i32 1), align 4 + %3 = add <4 x i32> %1, %2 + ret <4 x i32> %3 +} diff --git a/llvm/test/CodeGen/DirectX/scalar-store.ll b/llvm/test/CodeGen/DirectX/scalar-store.ll index aac4711c3f97f..08d8a2c57c6c3 100644 --- a/llvm/test/CodeGen/DirectX/scalar-store.ll +++ b/llvm/test/CodeGen/DirectX/scalar-store.ll @@ -1,6 +1,8 @@ ; RUN: opt -S -dxil-data-scalarization -scalarizer -scalarize-load-store -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s ; RUN: llc %s -mtriple=dxil-pc-shadermodel6.3-library --filetype=asm -o - | FileCheck %s +; Make sure we can store groupshared, static vectors and arrays of vectors + @"arrayofVecData" = local_unnamed_addr addrspace(3) global [2 x <3 x float>] zeroinitializer, align 16 @"vecData" = external addrspace(3) global <4 x i32>, align 4