Skip to content

Commit 12fb0d4

Browse files
Icohedron, inbelic, pow2clk
authored
[DirectX] Legalize memcpy (llvm#139173)
Fixes llvm#137188.

This PR legalizes memcpy for DXIL in cases where:
- the src and dst arguments are from an Alloca or a GlobalVariable,
- the src and dst are pointers to an ArrayType,
- the array element types of src and dst are equivalent, and
- the len param is a ConstantInt.

These assumptions simplify the legalization and, with the addition of llvm#138991, cover the currently-known cases of memcpy that appear when compiling DML shaders. This PR may be unnecessary if llvm#138788 determines that memset and memcpy can be eliminated entirely.

---------

Co-authored-by: Finn Plummer <[email protected]>
Co-authored-by: Greg Roth <[email protected]>
1 parent 5c37840 commit 12fb0d4

File tree

2 files changed

+243
-0
lines changed

2 files changed

+243
-0
lines changed

llvm/lib/Target/DirectX/DXILLegalizePass.cpp

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,67 @@ downcastI64toI32InsertExtractElements(Instruction &I,
246246
}
247247
}
248248

249+
// Expands a memcpy of `Length` bytes from `Src` to `Dst` into one
// GEP/load/GEP/store sequence per array element, emitted through `Builder`.
//
// The expansion relies on the following preconditions, which are only
// checked by asserts:
//  - `Dst` and `Src` are each an AllocaInst or a GlobalVariable whose
//    allocated/value type is an ArrayType,
//  - `Dst` is not a constant GlobalVariable,
//  - the element types of both arrays are identical,
//  - both arrays are at least `Length` bytes large, and
//  - `Length` is a multiple of the element store size.
// A `Length` of zero emits no instructions.
static void emitMemcpyExpansion(IRBuilder<> &Builder, Value *Dst, Value *Src,
                                ConstantInt *Length) {

  uint64_t ByteLength = Length->getZExtValue();
  // If length to copy is zero, no memcpy is needed.
  if (ByteLength == 0)
    return;

  LLVMContext &Ctx = Builder.getContext();
  const DataLayout &DL = Builder.GetInsertBlock()->getModule()->getDataLayout();

  // Returns the array type behind an Alloca or GlobalVariable, or nullptr if
  // the underlying type is not an array.
  auto GetArrTyFromVal = [](Value *Val) -> ArrayType * {
    // The `||` is parenthesized so the message string applies to the whole
    // condition; `&&` binds tighter than `||` and would otherwise attach the
    // string to the second operand only (and trigger -Wparentheses).
    assert((isa<AllocaInst>(Val) || isa<GlobalVariable>(Val)) &&
           "Expected Val to be an Alloca or Global Variable");
    if (auto *Alloca = dyn_cast<AllocaInst>(Val))
      return dyn_cast<ArrayType>(Alloca->getAllocatedType());
    if (auto *GlobalVar = dyn_cast<GlobalVariable>(Val))
      return dyn_cast<ArrayType>(GlobalVar->getValueType());
    return nullptr;
  };

  ArrayType *DstArrTy = GetArrTyFromVal(Dst);
  assert(DstArrTy && "Expected Dst of memcpy to be a Pointer to an Array Type");
  if (auto *DstGlobalVar = dyn_cast<GlobalVariable>(Dst))
    assert(!DstGlobalVar->isConstant() &&
           "The Dst of memcpy must not be a constant Global Variable");
  ArrayType *SrcArrTy = GetArrTyFromVal(Src);
  assert(SrcArrTy && "Expected Src of memcpy to be a Pointer to an Array Type");

  Type *DstElemTy = DstArrTy->getElementType();
  uint64_t DstElemByteSize = DL.getTypeStoreSize(DstElemTy);
  assert(DstElemByteSize > 0 && "Dst element type store size must be set");
  Type *SrcElemTy = SrcArrTy->getElementType();
  [[maybe_unused]] uint64_t SrcElemByteSize = DL.getTypeStoreSize(SrcElemTy);
  assert(SrcElemByteSize > 0 && "Src element type store size must be set");

  // This assumption simplifies implementation and covers currently-known
  // use-cases for DXIL. It may be relaxed in the future if required.
  assert(DstElemTy == SrcElemTy &&
         "The element types of Src and Dst arrays must match");

  [[maybe_unused]] uint64_t DstArrNumElems = DstArrTy->getArrayNumElements();
  assert(DstElemByteSize * DstArrNumElems >= ByteLength &&
         "Dst array size must be at least as large as the memcpy length");
  [[maybe_unused]] uint64_t SrcArrNumElems = SrcArrTy->getArrayNumElements();
  assert(SrcElemByteSize * SrcArrNumElems >= ByteLength &&
         "Src array size must be at least as large as the memcpy length");

  uint64_t NumElemsToCopy = ByteLength / DstElemByteSize;
  assert(ByteLength % DstElemByteSize == 0 &&
         "memcpy length must be divisible by array element type");

  // Element indices are emitted as i32 constants; the type is loop-invariant.
  Type *Int32Ty = Type::getInt32Ty(Ctx);
  for (uint64_t I = 0; I < NumElemsToCopy; ++I) {
    Value *Offset = ConstantInt::get(Int32Ty, I);
    Value *SrcPtr = Builder.CreateInBoundsGEP(SrcElemTy, Src, Offset, "gep");
    Value *SrcVal = Builder.CreateLoad(SrcElemTy, SrcPtr);
    Value *DstPtr = Builder.CreateInBoundsGEP(DstElemTy, Dst, Offset, "gep");
    Builder.CreateStore(SrcVal, DstPtr);
  }
}
309+
249310
static void emitMemsetExpansion(IRBuilder<> &Builder, Value *Dst, Value *Val,
250311
ConstantInt *SizeCI,
251312
DenseMap<Value *, Value *> &ReplacedValues) {
@@ -296,6 +357,33 @@ static void emitMemsetExpansion(IRBuilder<> &Builder, Value *Dst, Value *Val,
296357
}
297358
}
298359

360+
// Expands the instruction `I` into corresponding loads and stores if it is a
361+
// memcpy call. In that case, the call instruction is added to the `ToRemove`
362+
// vector. `ReplacedValues` is unused.
363+
static void legalizeMemCpy(Instruction &I,
364+
SmallVectorImpl<Instruction *> &ToRemove,
365+
DenseMap<Value *, Value *> &ReplacedValues) {
366+
367+
CallInst *CI = dyn_cast<CallInst>(&I);
368+
if (!CI)
369+
return;
370+
371+
Intrinsic::ID ID = CI->getIntrinsicID();
372+
if (ID != Intrinsic::memcpy)
373+
return;
374+
375+
IRBuilder<> Builder(&I);
376+
Value *Dst = CI->getArgOperand(0);
377+
Value *Src = CI->getArgOperand(1);
378+
ConstantInt *Length = dyn_cast<ConstantInt>(CI->getArgOperand(2));
379+
assert(Length && "Expected Length to be a ConstantInt");
380+
ConstantInt *IsVolatile = dyn_cast<ConstantInt>(CI->getArgOperand(3));
381+
assert(IsVolatile && "Expected IsVolatile to be a ConstantInt");
382+
assert(IsVolatile->getZExtValue() == 0 && "Expected IsVolatile to be false");
383+
emitMemcpyExpansion(Builder, Dst, Src, Length);
384+
ToRemove.push_back(CI);
385+
}
386+
299387
static void removeMemSet(Instruction &I,
300388
SmallVectorImpl<Instruction *> &ToRemove,
301389
DenseMap<Value *, Value *> &ReplacedValues) {
@@ -348,6 +436,7 @@ class DXILLegalizationPipeline {
348436
LegalizationPipeline.push_back(fixI8UseChain);
349437
LegalizationPipeline.push_back(downcastI64toI32InsertExtractElements);
350438
LegalizationPipeline.push_back(legalizeFreeze);
439+
LegalizationPipeline.push_back(legalizeMemCpy);
351440
LegalizationPipeline.push_back(removeMemSet);
352441
}
353442
};
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
2+
; RUN: opt -S -dxil-legalize -dxil-finalize-linkage -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
3+
4+
; Copies 4 bytes (one i32) from alloca %1 to alloca %2, both [1 x i32]. The
; expected output replaces the memcpy call with a single
; getelementptr/load/getelementptr/store sequence at index 0.
define void @replace_int_memcpy_test() #0 {
5+
; CHECK-LABEL: define void @replace_int_memcpy_test(
6+
; CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
7+
; CHECK-NEXT: [[TMP1:%.*]] = alloca [1 x i32], align 4
8+
; CHECK-NEXT: [[TMP2:%.*]] = alloca [1 x i32], align 4
9+
; CHECK-NEXT: [[GEP:%.*]] = getelementptr inbounds i32, ptr [[TMP1]], i32 0
10+
; CHECK-NEXT: [[TMP3:%.*]] = load i32, ptr [[GEP]], align 4
11+
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr inbounds i32, ptr [[TMP2]], i32 0
12+
; CHECK-NEXT: store i32 [[TMP3]], ptr [[GEP1]], align 4
13+
; CHECK-NEXT: ret void
14+
;
15+
%1 = alloca [1 x i32], align 4
16+
%2 = alloca [1 x i32], align 4
17+
call void @llvm.memcpy.p0.p0.i32(ptr nonnull align 4 dereferenceable(4) %2, ptr align 4 dereferenceable(4) %1, i32 4, i1 false)
18+
ret void
19+
}
20+
21+
; Copies 12 bytes (three i32 elements) from alloca %1 to alloca %2, both
; [3 x i32]. The expected output contains one load/store pair per element, at
; indices 0, 1 and 2, and no memcpy call.
define void @replace_3int_memcpy_test() #0 {
22+
; CHECK-LABEL: define void @replace_3int_memcpy_test(
23+
; CHECK-SAME: ) #[[ATTR0]] {
24+
; CHECK-NEXT: [[TMP1:%.*]] = alloca [3 x i32], align 4
25+
; CHECK-NEXT: [[TMP2:%.*]] = alloca [3 x i32], align 4
26+
; CHECK-NEXT: [[GEP:%.*]] = getelementptr inbounds i32, ptr [[TMP1]], i32 0
27+
; CHECK-NEXT: [[TMP3:%.*]] = load i32, ptr [[GEP]], align 4
28+
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr inbounds i32, ptr [[TMP2]], i32 0
29+
; CHECK-NEXT: store i32 [[TMP3]], ptr [[GEP1]], align 4
30+
; CHECK-NEXT: [[GEP2:%.*]] = getelementptr inbounds i32, ptr [[TMP1]], i32 1
31+
; CHECK-NEXT: [[TMP4:%.*]] = load i32, ptr [[GEP2]], align 4
32+
; CHECK-NEXT: [[GEP3:%.*]] = getelementptr inbounds i32, ptr [[TMP2]], i32 1
33+
; CHECK-NEXT: store i32 [[TMP4]], ptr [[GEP3]], align 4
34+
; CHECK-NEXT: [[GEP4:%.*]] = getelementptr inbounds i32, ptr [[TMP1]], i32 2
35+
; CHECK-NEXT: [[TMP5:%.*]] = load i32, ptr [[GEP4]], align 4
36+
; CHECK-NEXT: [[GEP5:%.*]] = getelementptr inbounds i32, ptr [[TMP2]], i32 2
37+
; CHECK-NEXT: store i32 [[TMP5]], ptr [[GEP5]], align 4
38+
; CHECK-NEXT: ret void
39+
;
40+
%1 = alloca [3 x i32], align 4
41+
%2 = alloca [3 x i32], align 4
42+
call void @llvm.memcpy.p0.p0.i32(ptr nonnull align 4 dereferenceable(12) %2, ptr align 4 dereferenceable(12) %1, i32 12, i1 false)
43+
ret void
44+
}
45+
46+
; Source and destination arrays differ in length: 8 bytes are copied from a
; [2 x i32] alloca (%1) into a larger [3 x i32] alloca (%2). The expected
; output copies only the two elements covered by the length (indices 0 and 1).
define void @replace_mismatched_size_int_memcpy_test() #0 {
47+
; CHECK-LABEL: define void @replace_mismatched_size_int_memcpy_test(
48+
; CHECK-SAME: ) #[[ATTR0]] {
49+
; CHECK-NEXT: [[TMP1:%.*]] = alloca [2 x i32], align 4
50+
; CHECK-NEXT: [[TMP2:%.*]] = alloca [3 x i32], align 4
51+
; CHECK-NEXT: [[GEP:%.*]] = getelementptr inbounds i32, ptr [[TMP1]], i32 0
52+
; CHECK-NEXT: [[TMP3:%.*]] = load i32, ptr [[GEP]], align 4
53+
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr inbounds i32, ptr [[TMP2]], i32 0
54+
; CHECK-NEXT: store i32 [[TMP3]], ptr [[GEP1]], align 4
55+
; CHECK-NEXT: [[GEP2:%.*]] = getelementptr inbounds i32, ptr [[TMP1]], i32 1
56+
; CHECK-NEXT: [[TMP4:%.*]] = load i32, ptr [[GEP2]], align 4
57+
; CHECK-NEXT: [[GEP3:%.*]] = getelementptr inbounds i32, ptr [[TMP2]], i32 1
58+
; CHECK-NEXT: store i32 [[TMP4]], ptr [[GEP3]], align 4
59+
; CHECK-NEXT: ret void
60+
;
61+
%1 = alloca [2 x i32], align 4
62+
%2 = alloca [3 x i32], align 4
63+
call void @llvm.memcpy.p0.p0.i32(ptr nonnull align 4 dereferenceable(12) %2, ptr align 4 dereferenceable(8) %1, i32 8, i1 false)
64+
ret void
65+
}
66+
67+
; Copies 4 bytes (two i16 elements) between two [2 x i16] allocas. The expected
; output uses i16 loads and stores with 2-byte alignment, one pair per element.
define void @replace_int16_memcpy_test() #0 {
68+
; CHECK-LABEL: define void @replace_int16_memcpy_test(
69+
; CHECK-SAME: ) #[[ATTR0]] {
70+
; CHECK-NEXT: [[TMP1:%.*]] = alloca [2 x i16], align 2
71+
; CHECK-NEXT: [[TMP2:%.*]] = alloca [2 x i16], align 2
72+
; CHECK-NEXT: [[GEP:%.*]] = getelementptr inbounds i16, ptr [[TMP1]], i32 0
73+
; CHECK-NEXT: [[TMP3:%.*]] = load i16, ptr [[GEP]], align 2
74+
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr inbounds i16, ptr [[TMP2]], i32 0
75+
; CHECK-NEXT: store i16 [[TMP3]], ptr [[GEP1]], align 2
76+
; CHECK-NEXT: [[GEP2:%.*]] = getelementptr inbounds i16, ptr [[TMP1]], i32 1
77+
; CHECK-NEXT: [[TMP4:%.*]] = load i16, ptr [[GEP2]], align 2
78+
; CHECK-NEXT: [[GEP3:%.*]] = getelementptr inbounds i16, ptr [[TMP2]], i32 1
79+
; CHECK-NEXT: store i16 [[TMP4]], ptr [[GEP3]], align 2
80+
; CHECK-NEXT: ret void
81+
;
82+
%1 = alloca [2 x i16], align 2
83+
%2 = alloca [2 x i16], align 2
84+
call void @llvm.memcpy.p0.p0.i32(ptr nonnull align 2 dereferenceable(4) %2, ptr align 2 dereferenceable(4) %1, i32 4, i1 false)
85+
ret void
86+
}
87+
88+
; Copies 8 bytes (two float elements) between two [2 x float] allocas. The
; expected output uses float loads and stores, one pair per element.
define void @replace_float_memcpy_test() #0 {
89+
; CHECK-LABEL: define void @replace_float_memcpy_test(
90+
; CHECK-SAME: ) #[[ATTR0]] {
91+
; CHECK-NEXT: [[TMP1:%.*]] = alloca [2 x float], align 4
92+
; CHECK-NEXT: [[TMP2:%.*]] = alloca [2 x float], align 4
93+
; CHECK-NEXT: [[GEP:%.*]] = getelementptr inbounds float, ptr [[TMP1]], i32 0
94+
; CHECK-NEXT: [[TMP3:%.*]] = load float, ptr [[GEP]], align 4
95+
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr inbounds float, ptr [[TMP2]], i32 0
96+
; CHECK-NEXT: store float [[TMP3]], ptr [[GEP1]], align 4
97+
; CHECK-NEXT: [[GEP2:%.*]] = getelementptr inbounds float, ptr [[TMP1]], i32 1
98+
; CHECK-NEXT: [[TMP4:%.*]] = load float, ptr [[GEP2]], align 4
99+
; CHECK-NEXT: [[GEP3:%.*]] = getelementptr inbounds float, ptr [[TMP2]], i32 1
100+
; CHECK-NEXT: store float [[TMP4]], ptr [[GEP3]], align 4
101+
; CHECK-NEXT: ret void
102+
;
103+
%1 = alloca [2 x float], align 4
104+
%2 = alloca [2 x float], align 4
105+
call void @llvm.memcpy.p0.p0.i32(ptr nonnull align 4 dereferenceable(8) %2, ptr align 4 dereferenceable(8) %1, i32 8, i1 false)
106+
ret void
107+
}
108+
109+
; Copies 16 bytes (two double elements) between two [2 x double] allocas. The
; expected output uses double loads and stores, one pair per element.
; NOTE(review): both pointer arguments of the call are marked
; dereferenceable(8) although 16 bytes are copied, and the allocas are align 4
; while the expected loads/stores are align 8 -- confirm both are intentional.
define void @replace_double_memcpy_test() #0 {
110+
; CHECK-LABEL: define void @replace_double_memcpy_test(
111+
; CHECK-SAME: ) #[[ATTR0]] {
112+
; CHECK-NEXT: [[TMP1:%.*]] = alloca [2 x double], align 4
113+
; CHECK-NEXT: [[TMP2:%.*]] = alloca [2 x double], align 4
114+
; CHECK-NEXT: [[GEP:%.*]] = getelementptr inbounds double, ptr [[TMP1]], i32 0
115+
; CHECK-NEXT: [[TMP3:%.*]] = load double, ptr [[GEP]], align 8
116+
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr inbounds double, ptr [[TMP2]], i32 0
117+
; CHECK-NEXT: store double [[TMP3]], ptr [[GEP1]], align 8
118+
; CHECK-NEXT: [[GEP2:%.*]] = getelementptr inbounds double, ptr [[TMP1]], i32 1
119+
; CHECK-NEXT: [[TMP4:%.*]] = load double, ptr [[GEP2]], align 8
120+
; CHECK-NEXT: [[GEP3:%.*]] = getelementptr inbounds double, ptr [[TMP2]], i32 1
121+
; CHECK-NEXT: store double [[TMP4]], ptr [[GEP3]], align 8
122+
; CHECK-NEXT: ret void
123+
;
124+
%1 = alloca [2 x double], align 4
125+
%2 = alloca [2 x double], align 4
126+
call void @llvm.memcpy.p0.p0.i32(ptr nonnull align 4 dereferenceable(8) %2, ptr align 4 dereferenceable(8) %1, i32 16, i1 false)
127+
ret void
128+
}
129+
130+
; Copies 4 bytes (two half elements) between two [2 x half] allocas. The
; expected output uses half loads and stores with 2-byte alignment, one pair
; per element.
define void @replace_half_memcpy_test() #0 {
131+
; CHECK-LABEL: define void @replace_half_memcpy_test(
132+
; CHECK-SAME: ) #[[ATTR0]] {
133+
; CHECK-NEXT: [[TMP1:%.*]] = alloca [2 x half], align 2
134+
; CHECK-NEXT: [[TMP2:%.*]] = alloca [2 x half], align 2
135+
; CHECK-NEXT: [[GEP:%.*]] = getelementptr inbounds half, ptr [[TMP1]], i32 0
136+
; CHECK-NEXT: [[TMP3:%.*]] = load half, ptr [[GEP]], align 2
137+
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr inbounds half, ptr [[TMP2]], i32 0
138+
; CHECK-NEXT: store half [[TMP3]], ptr [[GEP1]], align 2
139+
; CHECK-NEXT: [[GEP2:%.*]] = getelementptr inbounds half, ptr [[TMP1]], i32 1
140+
; CHECK-NEXT: [[TMP4:%.*]] = load half, ptr [[GEP2]], align 2
141+
; CHECK-NEXT: [[GEP3:%.*]] = getelementptr inbounds half, ptr [[TMP2]], i32 1
142+
; CHECK-NEXT: store half [[TMP4]], ptr [[GEP3]], align 2
143+
; CHECK-NEXT: ret void
144+
;
145+
%1 = alloca [2 x half], align 2
146+
%2 = alloca [2 x half], align 2
147+
call void @llvm.memcpy.p0.p0.i32(ptr nonnull align 2 dereferenceable(4) %2, ptr align 2 dereferenceable(4) %1, i32 4, i1 false)
148+
ret void
149+
}
150+
151+
attributes #0 = {"hlsl.export"}
152+
153+
declare void @llvm.memcpy.p0.p2.i32(ptr noalias, ptr addrspace(2) noalias readonly, i32, i1)
154+
declare void @llvm.memcpy.p0.p0.i32(ptr noalias, ptr noalias readonly, i32, i1)

0 commit comments

Comments
 (0)