Skip to content

Commit f6dbbfb

Browse files
committed
[HLSL] Implementation of dot intrinsic
This change implements #70073. HLSL has a dot intrinsic defined here: https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-dot The intrinsic itself is defined as a HLSL_LANG LangBuiltin in Builtins.td. This is used to associate all the dot product typedefs defined in hlsl_intrinsics.h with a single intrinsic check in CGBuiltin.cpp & SemaChecking.cpp. In IntrinsicsDirectX.td we define the LLVM IR for the dot product. A few goals were in mind for this IR. First, it should operate only on vectors. Second, the return type should be the vector element type. Third, the second parameter vector should be of the same size as the first parameter. Finally, `a dot b` should be the same as `b dot a`. In CGBuiltin.cpp HLSL has built on top of existing clang intrinsics via EmitBuiltinExpr. The dot product, though, is a language-specific intrinsic and so is guarded behind getLangOpts().HLSL. The call chain looks like this: EmitBuiltinExpr -> EmitHLSLBuiltinExpr. In EmitHLSLBuiltinExpr, the dot product intrinsic makes a distinction between vectors and scalars. This is because HLSL supports dot product on scalars, which simplifies down to a multiply. Sema.h & SemaChecking.cpp saw the addition of CheckHLSLBuiltinFunctionCall, a language-specific semantic validation that can be expanded for other HLSL-specific intrinsics.
1 parent 4bf50e0 commit f6dbbfb

File tree

10 files changed

+497
-5
lines changed

10 files changed

+497
-5
lines changed

clang/include/clang/Basic/Builtins.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4524,6 +4524,12 @@ def HLSLCreateHandle : LangBuiltin<"HLSL_LANG"> {
45244524
let Prototype = "void*(unsigned char)";
45254525
}
45264526

4527+
// HLSL-only builtin spelled __builtin_hlsl_dot.
// NOTE(review): with CustomTypeChecking set, the "void(...)" prototype is
// presumably a placeholder and the real argument validation lives in Sema --
// confirm against CheckHLSLBuiltinFunctionCall.
def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> {
4528+
let Spellings = ["__builtin_hlsl_dot"];
4529+
let Attributes = [NoThrow, Const, CustomTypeChecking];
4530+
let Prototype = "void(...)";
4531+
}
4532+
45274533
// Builtins for XRay.
45284534
def XRayCustomEvent : Builtin {
45294535
let Spellings = ["__xray_customevent"];

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10267,6 +10267,8 @@ def err_vec_builtin_non_vector : Error<
1026710267
"first two arguments to %0 must be vectors">;
1026810268
def err_vec_builtin_incompatible_vector : Error<
1026910269
"first two arguments to %0 must have the same type">;
10270+
// Reported when the two vector arguments of a builtin (%0) have different
// element counts.
def err_vec_builtin_incompatible_size : Error<
10271+
"first two arguments to %0 must have the same size">;
1027010272
def err_vsx_builtin_nonconstant_argument : Error<
1027110273
"argument %0 to %1 must be a 2-bit unsigned literal (i.e. 0, 1, 2 or 3)">;
1027210274

clang/include/clang/Sema/Sema.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14054,6 +14054,7 @@ class Sema final {
1405414054
bool CheckPPCBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
1405514055
CallExpr *TheCall);
1405614056
bool CheckAMDGCNBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
14057+
bool CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
1405714058
bool CheckRISCVLMUL(CallExpr *TheCall, unsigned ArgNum);
1405814059
bool CheckRISCVBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
1405914060
CallExpr *TheCall);

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
#include "llvm/IR/IntrinsicsAMDGPU.h"
4545
#include "llvm/IR/IntrinsicsARM.h"
4646
#include "llvm/IR/IntrinsicsBPF.h"
47+
#include "llvm/IR/IntrinsicsDirectX.h"
4748
#include "llvm/IR/IntrinsicsHexagon.h"
4849
#include "llvm/IR/IntrinsicsNVPTX.h"
4950
#include "llvm/IR/IntrinsicsPowerPC.h"
@@ -5982,6 +5983,10 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
59825983
llvm_unreachable("Bad evaluation kind in EmitBuiltinExpr");
59835984
}
59845985

5986+
// EmitHLSLBuiltinExpr will check getLangOpts().HLSL
5987+
if (Value *V = EmitHLSLBuiltinExpr(BuiltinID, E))
5988+
return RValue::get(V);
5989+
59855990
if (getLangOpts().HIPStdPar && getLangOpts().CUDAIsDevice)
59865991
return EmitHipStdParUnsupportedBuiltin(this, FD);
59875992

@@ -17895,6 +17900,50 @@ llvm::Value *CodeGenFunction::EmitScalarOrConstFoldImmArg(unsigned ICEArguments,
1789517900
return Arg;
1789617901
}
1789717902

17903+
// Emits IR for HLSL-specific builtins.
//
// Returns nullptr when HLSL is not the active language or when BuiltinID is
// not one of the cases handled below; otherwise returns the emitted value.
//
// __builtin_hlsl_dot:
//   - neither operand is a vector: emits a plain multiply named "dx.dot"
//     (FMul for floating point, Mul for integers);
//   - otherwise: emits a call to the dx_dot intrinsic whose return type is the
//     scalar (element) type of the first operand.
Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
17904+
const CallExpr *E) {
17905+
// Language guard: every case below is HLSL-only.
if (!getLangOpts().HLSL)
17906+
return nullptr;
17907+
17908+
switch (BuiltinID) {
17909+
case Builtin::BI__builtin_hlsl_dot: {
17910+
Value *Op0 = EmitScalarExpr(E->getArg(0));
17911+
Value *Op1 = EmitScalarExpr(E->getArg(1));
17912+
llvm::Type *T0 = Op0->getType();
17913+
llvm::Type *T1 = Op1->getType();
17914+
// Scalar x scalar: the dot product is just a multiply. Only T0 is inspected
// to pick between the floating-point and integer multiply.
if (!T0->isVectorTy() && !T1->isVectorTy()) {
17915+
if (T0->isFloatingPointTy()) {
17916+
return Builder.CreateFMul(Op0, Op1, "dx.dot");
17917+
}
17918+
17919+
if (T0->isIntegerTy()) {
17920+
return Builder.CreateMul(Op0, Op1, "dx.dot");
17921+
}
17922+
// NOTE(review): this assert does not return. If asserts are compiled out,
// control presumably continues into the vector path below with scalar
// operands -- confirm whether llvm_unreachable or an explicit return is
// intended here.
assert(
17923+
false &&
17924+
"Dot product on a scalar is only supported on integers and floats.");
17925+
}
17926+
// From here on both operands are expected to be vectors.
assert(T0->isVectorTy() && T1->isVectorTy() &&
17927+
"Dot product of vector and scalar is not supported.");
17928+
17929+
// NOTE: this assert will need to be revisited after overload resolution
17930+
// PR merges.
17931+
assert(T0->getScalarType() == T1->getScalarType() &&
17932+
"Dot product of vectors need the same element types.");
17933+
17934+
// NOTE(review): VecTy0/VecTy1 are only read inside the assert below, so
// they are likely unused when asserts are disabled -- TODO confirm.
auto *VecTy0 = E->getArg(0)->getType()->getAs<VectorType>();
17935+
auto *VecTy1 = E->getArg(1)->getType()->getAs<VectorType>();
17936+
assert(VecTy0->getNumElements() == VecTy1->getNumElements() &&
17937+
"Dot product requires vectors to be of the same size.");
17938+
17939+
// Vector x vector: call the dx_dot intrinsic; the result type is the
// element type of the first operand.
return Builder.CreateIntrinsic(
17940+
/*ReturnType*/ T0->getScalarType(), Intrinsic::dx_dot,
17941+
ArrayRef<Value *>{Op0, Op1}, nullptr, "dx.dot");
17942+
} break;
17943+
}
17944+
// Not an HLSL builtin handled here.
return nullptr;
17945+
}
17946+
1789817947
Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
1789917948
const CallExpr *E) {
1790017949
llvm::AtomicOrdering AO = llvm::AtomicOrdering::SequentiallyConsistent;

clang/lib/CodeGen/CodeGenFunction.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4405,6 +4405,7 @@ class CodeGenFunction : public CodeGenTypeCache {
44054405
llvm::Value *EmitX86BuiltinExpr(unsigned BuiltinID, const CallExpr *E);
44064406
llvm::Value *EmitPPCBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
44074407
llvm::Value *EmitAMDGPUBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
4408+
llvm::Value *EmitHLSLBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
44084409
llvm::Value *EmitScalarOrConstFoldImmArg(unsigned ICEArguments, unsigned Idx,
44094410
const CallExpr *E);
44104411
llvm::Value *EmitSystemZBuiltinExpr(unsigned BuiltinID, const CallExpr *E);

clang/lib/Headers/hlsl/hlsl_intrinsics.h

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,98 @@ double3 cos(double3);
179179
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_cos)
180180
double4 cos(double4);
181181

182+
//===----------------------------------------------------------------------===//
183+
// dot product builtins
184+
//===----------------------------------------------------------------------===//
185+
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
186+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
187+
half dot(half, half);
188+
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
189+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
190+
half dot(half2, half2);
191+
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
192+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
193+
half dot(half3, half3);
194+
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
195+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
196+
half dot(half4, half4);
197+
198+
#ifdef __HLSL_ENABLE_16_BIT
199+
_HLSL_AVAILABILITY(shadermodel, 6.2)
200+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
201+
int16_t dot(int16_t, int16_t);
202+
_HLSL_AVAILABILITY(shadermodel, 6.2)
203+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
204+
int16_t dot(int16_t2, int16_t2);
205+
_HLSL_AVAILABILITY(shadermodel, 6.2)
206+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
207+
int16_t dot(int16_t3, int16_t3);
208+
_HLSL_AVAILABILITY(shadermodel, 6.2)
209+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
210+
int16_t dot(int16_t4, int16_t4);
211+
212+
_HLSL_AVAILABILITY(shadermodel, 6.2)
213+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
214+
uint16_t dot(uint16_t, uint16_t);
215+
_HLSL_AVAILABILITY(shadermodel, 6.2)
216+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
217+
uint16_t dot(uint16_t2, uint16_t2);
218+
_HLSL_AVAILABILITY(shadermodel, 6.2)
219+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
220+
uint16_t dot(uint16_t3, uint16_t3);
221+
_HLSL_AVAILABILITY(shadermodel, 6.2)
222+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
223+
uint16_t dot(uint16_t4, uint16_t4);
224+
#endif
225+
226+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
227+
float dot(float, float);
228+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
229+
float dot(float2, float2);
230+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
231+
float dot(float3, float3);
232+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
233+
float dot(float4, float4);
234+
235+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
236+
double dot(double, double);
237+
238+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
239+
int dot(int, int);
240+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
241+
int dot(int2, int2);
242+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
243+
int dot(int3, int3);
244+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
245+
int dot(int4, int4);
246+
247+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
248+
uint dot(uint, uint);
249+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
250+
uint dot(uint2, uint2);
251+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
252+
uint dot(uint3, uint3);
253+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
254+
uint dot(uint4, uint4);
255+
256+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
257+
int64_t dot(int64_t, int64_t);
258+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
259+
int64_t dot(int64_t2, int64_t2);
260+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
261+
int64_t dot(int64_t3, int64_t3);
262+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
263+
int64_t dot(int64_t4, int64_t4);
264+
265+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
266+
uint64_t dot(uint64_t, uint64_t);
267+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
268+
uint64_t dot(uint64_t2, uint64_t2);
269+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
270+
uint64_t dot(uint64_t3, uint64_t3);
271+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
272+
uint64_t dot(uint64_t4, uint64_t4);
273+
182274
//===----------------------------------------------------------------------===//
183275
// floor builtins
184276
//===----------------------------------------------------------------------===//

clang/lib/Sema/SemaChecking.cpp

Lines changed: 79 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2120,10 +2120,11 @@ bool Sema::CheckTSBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
21202120
// not a valid type, emit an error message and return true. Otherwise return
21212121
// false.
21222122
static bool checkMathBuiltinElementType(Sema &S, SourceLocation Loc,
2123-
QualType Ty) {
2124-
if (!Ty->getAs<VectorType>() && !ConstantMatrixType::isValidElementType(Ty)) {
2123+
QualType ArgTy, int ArgIndex) {
2124+
if (!ArgTy->getAs<VectorType>() &&
2125+
!ConstantMatrixType::isValidElementType(ArgTy)) {
21252126
return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
2126-
<< 1 << /* vector, integer or float ty*/ 0 << Ty;
2127+
<< ArgIndex << /* vector, integer or float ty*/ 0 << ArgTy;
21272128
}
21282129

21292130
return false;
@@ -2958,6 +2959,10 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
29582959
}
29592960
}
29602961

2962+
if (getLangOpts().HLSL && CheckHLSLBuiltinFunctionCall(BuiltinID, TheCall)) {
2963+
return ExprError();
2964+
}
2965+
29612966
// Since the target specific builtins for each arch overlap, only check those
29622967
// of the arch we are compiling for.
29632968
if (Context.BuiltinInfo.isTSBuiltin(BuiltinID)) {
@@ -5158,6 +5163,75 @@ bool Sema::CheckPPCMMAType(QualType Type, SourceLocation TypeLoc) {
51585163
return false;
51595164
}
51605165

5166+
// Semantic checks for HLSL-specific builtins.
//
// Returns true when the call is invalid and false when it is accepted or when
// BuiltinID is not handled by the switch below.
//
// Note: returning true in this case results in CheckBuiltinFunctionCall
5167+
// returning an ExprError
5168+
bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
5169+
switch (BuiltinID) {
5170+
case Builtin::BI__builtin_hlsl_dot: {
5171+
// dot takes exactly two arguments.
if (checkArgCount(*this, TheCall, 2)) {
5172+
return true;
5173+
}
5174+
Expr *Arg0 = TheCall->getArg(0);
5175+
QualType ArgTy0 = Arg0->getType();
5176+
5177+
Expr *Arg1 = TheCall->getArg(1);
5178+
QualType ArgTy1 = Arg1->getType();
5179+
5180+
// Null when the corresponding argument is not a vector type.
auto *VecTy0 = ArgTy0->getAs<VectorType>();
5181+
auto *VecTy1 = ArgTy1->getAs<VectorType>();
5182+
SourceLocation BuiltinLoc = TheCall->getBeginLoc();
5183+
5184+
// Diagnose arg0 (reported as argument 1) with err_builtin_invalid_arg_type
// when it is neither a vector nor a valid matrix element type.
5185+
if (checkMathBuiltinElementType(*this, Arg0->getBeginLoc(), ArgTy0, 1)) {
5186+
return true;
5187+
}
5188+
5189+
// Same check for arg1 (reported as argument 2).
5190+
if (checkMathBuiltinElementType(*this, Arg1->getBeginLoc(), ArgTy1, 2)) {
5191+
return true;
5192+
}
5193+
5194+
// Scalar x scalar: accepted only when both argument types are identical.
// NOTE(review): the mismatch path returns true (an error) without emitting
// any diagnostic -- confirm whether a Diag is needed here.
if (VecTy0 == nullptr && VecTy1 == nullptr) {
5195+
if (ArgTy0 != ArgTy1) {
5196+
return true;
5197+
} else {
5198+
return false;
5199+
}
5200+
}
5201+
5202+
// Exactly one of the two arguments is a vector: rejected.
if ((VecTy0 == nullptr && VecTy1 != nullptr) ||
5203+
(VecTy0 != nullptr && VecTy1 == nullptr)) {
5204+
5205+
Diag(BuiltinLoc, diag::err_vec_builtin_non_vector)
5206+
<< TheCall->getDirectCallee()
5207+
<< SourceRange(TheCall->getArg(0)->getBeginLoc(),
5208+
TheCall->getArg(1)->getEndLoc());
5209+
return true;
5210+
}
5211+
5212+
// Both arguments are vectors from here on.
if (VecTy0->getElementType() != VecTy1->getElementType()) {
5213+
// Note: This case should never happen. If type promotion occurs
5214+
// then element types won't be different. This diag error is here
5215+
// because EmitHLSLBuiltinExpr asserts on this case.
5216+
Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_vector)
5217+
<< TheCall->getDirectCallee()
5218+
<< SourceRange(TheCall->getArg(0)->getBeginLoc(),
5219+
TheCall->getArg(1)->getEndLoc());
5220+
return true;
5221+
}
5222+
// The two vectors must have the same number of elements.
if (VecTy0->getNumElements() != VecTy1->getNumElements()) {
5223+
Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_size)
5224+
<< TheCall->getDirectCallee()
5225+
<< SourceRange(TheCall->getArg(0)->getBeginLoc(),
5226+
TheCall->getArg(1)->getEndLoc());
5227+
return true;
5228+
}
5229+
break;
5230+
}
5231+
}
5232+
// No HLSL-specific problem found (or the builtin is not handled here).
return false;
5233+
}
5234+
51615235
bool Sema::CheckAMDGCNBuiltinFunctionCall(unsigned BuiltinID,
51625236
CallExpr *TheCall) {
51635237
// position of memory order and scope arguments in the builtin
@@ -19583,7 +19657,7 @@ bool Sema::PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall) {
1958319657
TheCall->setArg(0, A.get());
1958419658
QualType TyA = A.get()->getType();
1958519659

19586-
if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA))
19660+
if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA, 1))
1958719661
return true;
1958819662

1958919663
TheCall->setType(TyA);
@@ -19611,7 +19685,7 @@ bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) {
1961119685
diag::err_typecheck_call_different_arg_types)
1961219686
<< TyA << TyB;
1961319687

19614-
if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA))
19688+
if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA, 1))
1961519689
return true;
1961619690

1961719691
TheCall->setArg(0, A.get());

0 commit comments

Comments
 (0)