Skip to content

Commit d78ffd2

Browse files
committed
[DirectX] Lower @llvm.dx.typedBufferLoad to DXIL ops
The `@llvm.dx.typedBufferLoad` intrinsic is lowered to `@dx.op.bufferLoad`. There's some complexity here due to translating from a vector return type to a named struct and trying to avoid excessive IR coming out of that. Note that this change includes a bit of a hack in how it deals with `getOverloadKind` for the `dx.ResRet` types - we need to adjust how we deal with operation overloads to generate a table directly rather than proxy through the OverloadKind enum, but that's left for a later change here. Pull Request: llvm#104252
1 parent c0d8b67 commit d78ffd2

File tree

7 files changed

+210
-8
lines changed

7 files changed

+210
-8
lines changed

llvm/include/llvm/IR/IntrinsicsDirectX.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ def int_dx_handle_fromBinding
3030
[llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i1_ty],
3131
[IntrNoMem]>;
3232

33+
def int_dx_typedBufferLoad
34+
: DefaultAttrsIntrinsic<[llvm_anyvector_ty],
35+
[llvm_any_ty, llvm_i32_ty]>;
36+
3337
// Cast between target extension handle types and dxil-style opaque handles
3438
def int_dx_cast_handle : Intrinsic<[llvm_any_ty], [llvm_any_ty]>;
3539

llvm/lib/Target/DirectX/DXIL.td

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,10 @@ def Int64Ty : DXILOpParamType;
4040
def HalfTy : DXILOpParamType;
4141
def FloatTy : DXILOpParamType;
4242
def DoubleTy : DXILOpParamType;
43-
def ResRetTy : DXILOpParamType;
43+
def ResRetHalfTy : DXILOpParamType;
44+
def ResRetFloatTy : DXILOpParamType;
45+
def ResRetInt16Ty : DXILOpParamType;
46+
def ResRetInt32Ty : DXILOpParamType;
4447
def HandleTy : DXILOpParamType;
4548
def ResBindTy : DXILOpParamType;
4649
def ResPropsTy : DXILOpParamType;
@@ -683,6 +686,17 @@ def CreateHandle : DXILOp<57, createHandle> {
683686
let stages = [Stages<DXIL1_0, [all_stages]>];
684687
}
685688

689+
def BufferLoad : DXILOp<68, bufferLoad> {
690+
let Doc = "reads from a TypedBuffer";
691+
// Handle, Coord0, Coord1
692+
let arguments = [HandleTy, Int32Ty, Int32Ty];
693+
let result = OverloadTy;
694+
let overloads =
695+
[Overloads<DXIL1_0,
696+
[ResRetHalfTy, ResRetFloatTy, ResRetInt16Ty, ResRetInt32Ty]>];
697+
let stages = [Stages<DXIL1_0, [all_stages]>];
698+
}
699+
686700
def ThreadId : DXILOp<93, threadId> {
687701
let Doc = "Reads the thread ID";
688702
let LLVMIntrinsic = int_dx_thread_id;

llvm/lib/Target/DirectX/DXILOpBuilder.cpp

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,15 @@ static OverloadKind getOverloadKind(Type *Ty) {
120120
}
121121
case Type::PointerTyID:
122122
return OverloadKind::UserDefineType;
123-
case Type::StructTyID:
123+
case Type::StructTyID: {
124+
// TODO: This is a hack. As described in DXILEmitter.cpp, we need to rework
125+
// how we're handling overloads and remove the `OverloadKind` proxy enum.
126+
StructType *ST = cast<StructType>(Ty);
127+
if (ST->hasName() && ST->getName().starts_with("dx.types.ResRet"))
128+
return getOverloadKind(ST->getElementType(0));
129+
124130
return OverloadKind::ObjectType;
131+
}
125132
default:
126133
llvm_unreachable("invalid overload type");
127134
return OverloadKind::VOID;
@@ -195,10 +202,11 @@ static StructType *getOrCreateStructType(StringRef Name,
195202
return StructType::create(Ctx, EltTys, Name);
196203
}
197204

198-
static StructType *getResRetType(Type *OverloadTy, LLVMContext &Ctx) {
199-
OverloadKind Kind = getOverloadKind(OverloadTy);
205+
static StructType *getResRetType(Type *ElementTy) {
206+
LLVMContext &Ctx = ElementTy->getContext();
207+
OverloadKind Kind = getOverloadKind(ElementTy);
200208
std::string TypeName = constructOverloadTypeName(Kind, "dx.types.ResRet.");
201-
Type *FieldTypes[5] = {OverloadTy, OverloadTy, OverloadTy, OverloadTy,
209+
Type *FieldTypes[5] = {ElementTy, ElementTy, ElementTy, ElementTy,
202210
Type::getInt32Ty(Ctx)};
203211
return getOrCreateStructType(TypeName, FieldTypes, Ctx);
204212
}
@@ -248,8 +256,14 @@ static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx,
248256
return Type::getInt64Ty(Ctx);
249257
case OpParamType::OverloadTy:
250258
return OverloadTy;
251-
case OpParamType::ResRetTy:
252-
return getResRetType(OverloadTy, Ctx);
259+
case OpParamType::ResRetHalfTy:
260+
return getResRetType(Type::getHalfTy(Ctx));
261+
case OpParamType::ResRetFloatTy:
262+
return getResRetType(Type::getFloatTy(Ctx));
263+
case OpParamType::ResRetInt16Ty:
264+
return getResRetType(Type::getInt16Ty(Ctx));
265+
case OpParamType::ResRetInt32Ty:
266+
return getResRetType(Type::getInt32Ty(Ctx));
253267
case OpParamType::HandleTy:
254268
return getHandleType(Ctx);
255269
case OpParamType::ResBindTy:
@@ -391,6 +405,7 @@ Expected<CallInst *> DXILOpBuilder::tryCreateOp(dxil::OpCode OpCode,
391405
return makeOpError(OpCode, "Wrong number of arguments");
392406
OverloadTy = Args[ArgIndex]->getType();
393407
}
408+
394409
FunctionType *DXILOpFT =
395410
getDXILOpFunctionType(OpCode, M.getContext(), OverloadTy);
396411

@@ -451,6 +466,10 @@ CallInst *DXILOpBuilder::createOp(dxil::OpCode OpCode, ArrayRef<Value *> Args,
451466
return *Result;
452467
}
453468

469+
/// Return the `%dx.types.ResRet` struct type for \p ElementTy by forwarding to
/// the file-local `::getResRetType` helper.
StructType *DXILOpBuilder::getResRetType(Type *ElementTy) {
470+
return ::getResRetType(ElementTy);
471+
}
472+
454473
StructType *DXILOpBuilder::getHandleType() {
455474
return ::getHandleType(IRB.getContext());
456475
}

llvm/lib/Target/DirectX/DXILOpBuilder.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ class DXILOpBuilder {
4646
Expected<CallInst *> tryCreateOp(dxil::OpCode Op, ArrayRef<Value *> Args,
4747
Type *RetTy = nullptr);
4848

49+
/// Get a `%dx.types.ResRet` type with the given element type.
50+
StructType *getResRetType(Type *ElementTy);
4951
/// Get the `%dx.types.Handle` type.
5052
StructType *getHandleType();
5153

llvm/lib/Target/DirectX/DXILOpLowering.cpp

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,59 @@ class OpLowerer {
236236
lowerToBindAndAnnotateHandle(F);
237237
}
238238

239+
/// Lower every call to `@llvm.dx.typedBufferLoad` in \p F to the
/// `BufferLoad` DXIL op.
///
/// The intrinsic returns a vector while the DXIL op returns a
/// `%dx.types.ResRet` struct, so the users of each call are rewritten:
/// `extractelement`s with an in-range constant index become `extractvalue`s
/// on the struct, and any remaining users receive a vector rebuilt from the
/// struct's fields.
///
/// A call whose result is not a fixed vector of at most four elements is
/// reported as an error rather than lowered.
void lowerTypedBufferLoad(Function &F) {
  IRBuilder<> &IRB = OpBuilder.getIRB();
  Type *Int32Ty = IRB.getInt32Ty();

  replaceFunction(F, [&](CallInst *CI) -> Error {
    // Number of per-element slots tracked below.
    constexpr unsigned MaxElts = 4;

    // Validate the result type before emitting anything. The rewrite indexes
    // `Extracts` by lane and rebuilds the vector lane by lane, so the lane
    // count must be known and must fit in `Extracts`.
    auto *VecTy = dyn_cast<FixedVectorType>(CI->getType());
    if (!VecTy || VecTy->getNumElements() > MaxElts)
      return make_error<StringError>(
          "typedBufferLoad result must be a vector of at most 4 elements",
          inconvertibleErrorCode());
    const unsigned NumElts = VecTy->getNumElements();

    IRB.SetInsertPoint(CI);

    Value *Handle =
        createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
    Value *Index0 = CI->getArgOperand(1);
    // The intrinsic carries a single index; the op's second coordinate is
    // passed as undef.
    Value *Index1 = UndefValue::get(Int32Ty);
    Type *RetTy = OpBuilder.getResRetType(CI->getType()->getScalarType());

    std::array<Value *, 3> Args{Handle, Index0, Index1};
    Expected<CallInst *> OpCall =
        OpBuilder.tryCreateOp(OpCode::BufferLoad, Args, RetTy);
    if (Error E = OpCall.takeError())
      return E;

    // One lazily created `extractvalue` per lane, shared by all users.
    std::array<Value *, MaxElts> Extracts = {};

    // We've switched the return type from a vector to a struct, but at this
    // point most vectors have probably already been scalarized. Try to
    // forward arguments directly rather than inserting into and immediately
    // extracting from a vector.
    for (Use &U : make_early_inc_range(CI->uses())) {
      auto *EEI = dyn_cast<ExtractElementInst>(U.getUser());
      if (!EEI)
        continue;
      auto *Index = dyn_cast<ConstantInt>(EEI->getIndexOperand());
      if (!Index)
        continue;
      // An out-of-range constant index is still valid IR. Leave such an
      // extract in place so it reads from the rebuilt vector below instead
      // of indexing past the end of `Extracts`.
      if (Index->getValue().uge(NumElts))
        continue;

      const unsigned IndexVal = Index->getZExtValue();
      if (!Extracts[IndexVal])
        Extracts[IndexVal] = IRB.CreateExtractValue(*OpCall, IndexVal);
      EEI->replaceAllUsesWith(Extracts[IndexVal]);
      EEI->eraseFromParent();
    }

    // If there are still uses then we need to create a vector. Only the
    // lanes that exist in the result type are populated; inserting at a lane
    // outside the vector would poison the whole value.
    if (!CI->use_empty()) {
      for (unsigned I = 0; I != NumElts; ++I)
        if (!Extracts[I])
          Extracts[I] = IRB.CreateExtractValue(*OpCall, I);

      Value *Vec = UndefValue::get(CI->getType());
      for (unsigned I = 0; I != NumElts; ++I)
        Vec = IRB.CreateInsertElement(Vec, Extracts[I], I);
      CI->replaceAllUsesWith(Vec);
    }

    CI->eraseFromParent();
    return Error::success();
  });
}
291+
239292
bool lowerIntrinsics() {
240293
bool Updated = false;
241294

@@ -253,6 +306,10 @@ class OpLowerer {
253306
#include "DXILOperation.inc"
254307
case Intrinsic::dx_handle_fromBinding:
255308
lowerHandleFromBinding(F);
309+
break;
310+
case Intrinsic::dx_typedBufferLoad:
311+
lowerTypedBufferLoad(F);
312+
break;
256313
}
257314
Updated = true;
258315
}
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
; RUN: opt -S -dxil-op-lower %s | FileCheck %s
2+
3+
target triple = "dxil-pc-shadermodel6.6-compute"
4+
5+
declare void @scalar_user(float)
6+
declare void @vector_user(<4 x float>)
7+
8+
define void @loadfloats() {
9+
; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
10+
; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
11+
%buffer = call target("dx.TypedBuffer", <4 x float>, 0, 0, 0)
12+
@llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_0_0_0(
13+
i32 0, i32 0, i32 1, i32 0, i1 false)
14+
15+
; The temporary casts should all have been cleaned up
16+
; CHECK-NOT: %dx.cast_handle
17+
18+
; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
19+
%data0 = call <4 x float> @llvm.dx.typedBufferLoad(
20+
target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 0)
21+
22+
; The extract order depends on the users, so don't enforce that here.
23+
; CHECK-DAG: extractvalue %dx.types.ResRet.f32 [[DATA0]], 0
24+
%data0_0 = extractelement <4 x float> %data0, i32 0
25+
; CHECK-DAG: extractvalue %dx.types.ResRet.f32 [[DATA0]], 2
26+
%data0_2 = extractelement <4 x float> %data0, i32 2
27+
28+
; If all of the uses are extracts, we skip creating a vector
29+
; CHECK-NOT: insertelement
30+
call void @scalar_user(float %data0_0)
31+
call void @scalar_user(float %data0_2)
32+
33+
; CHECK: [[DATA4:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 4, i32 undef)
34+
%data4 = call <4 x float> @llvm.dx.typedBufferLoad(
35+
target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 4)
36+
37+
; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 0
38+
; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 1
39+
; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 2
40+
; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 3
41+
; CHECK: insertelement <4 x float> undef
42+
; CHECK: insertelement <4 x float>
43+
; CHECK: insertelement <4 x float>
44+
; CHECK: insertelement <4 x float>
45+
call void @vector_user(<4 x float> %data4)
46+
47+
; CHECK: [[DATA12:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 12, i32 undef)
48+
%data12 = call <4 x float> @llvm.dx.typedBufferLoad(
49+
target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 12)
50+
51+
; CHECK: [[DATA12_3:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA12]], 3
52+
%data12_3 = extractelement <4 x float> %data12, i32 3
53+
54+
; If there are a mix of users we need the vector, but extracts are direct
55+
; CHECK: call void @scalar_user(float [[DATA12_3]])
56+
call void @scalar_user(float %data12_3)
57+
call void @vector_user(<4 x float> %data12)
58+
59+
ret void
60+
}
61+
62+
define void @loadint() {
63+
; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
64+
; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
65+
%buffer = call target("dx.TypedBuffer", <4 x i32>, 0, 0, 0)
66+
@llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4i32_0_0_0(
67+
i32 0, i32 0, i32 1, i32 0, i1 false)
68+
69+
; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.i32 @dx.op.bufferLoad.i32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
70+
%data0 = call <4 x i32> @llvm.dx.typedBufferLoad(
71+
target("dx.TypedBuffer", <4 x i32>, 0, 0, 0) %buffer, i32 0)
72+
73+
ret void
74+
}
75+
76+
define void @loadhalf() {
77+
; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
78+
; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
79+
%buffer = call target("dx.TypedBuffer", <4 x half>, 0, 0, 0)
80+
@llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f16_0_0_0(
81+
i32 0, i32 0, i32 1, i32 0, i1 false)
82+
83+
; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f16 @dx.op.bufferLoad.f16(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
84+
%data0 = call <4 x half> @llvm.dx.typedBufferLoad(
85+
target("dx.TypedBuffer", <4 x half>, 0, 0, 0) %buffer, i32 0)
86+
87+
ret void
88+
}
89+
90+
define void @loadi16() {
91+
; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
92+
; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
93+
%buffer = call target("dx.TypedBuffer", <4 x i16>, 0, 0, 0)
94+
@llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4i16_0_0_0(
95+
i32 0, i32 0, i32 1, i32 0, i1 false)
96+
97+
; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.i16 @dx.op.bufferLoad.i16(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
98+
%data0 = call <4 x i16> @llvm.dx.typedBufferLoad(
99+
target("dx.TypedBuffer", <4 x i16>, 0, 0, 0) %buffer, i32 0)
100+
101+
ret void
102+
}

llvm/utils/TableGen/DXILEmitter.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,11 @@ static StringRef getOverloadKindStr(const Record *R) {
187187
.Case("Int8Ty", "OverloadKind::I8")
188188
.Case("Int16Ty", "OverloadKind::I16")
189189
.Case("Int32Ty", "OverloadKind::I32")
190-
.Case("Int64Ty", "OverloadKind::I64");
190+
.Case("Int64Ty", "OverloadKind::I64")
191+
.Case("ResRetHalfTy", "OverloadKind::HALF")
192+
.Case("ResRetFloatTy", "OverloadKind::FLOAT")
193+
.Case("ResRetInt16Ty", "OverloadKind::I16")
194+
.Case("ResRetInt32Ty", "OverloadKind::I32");
191195
}
192196

193197
/// Return a string representation of valid overload information denoted

0 commit comments

Comments
 (0)