Skip to content

Commit fad7e6e

Browse files
committed
[HLSL] Implementation of dot intrinsic
This change implements #70073. HLSL has a dot intrinsic defined here: https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-dot The intrinsic itself is defined as an HLSL_LANG LangBuiltin in Builtins.td. This is used to associate all the dot product typedefs defined in hlsl_intrinsics.h with a single intrinsic check in CGBuiltin.cpp & SemaChecking.cpp. In IntrinsicsDirectX.td we define the LLVM IR for the dot product. A few goals were in mind for this IR. First, it should operate only on vectors. Second, the return type should be the vector element type. Third, the second parameter vector should be of the same size as the first parameter. Finally, `a dot b` should be the same as `b dot a`. In CGBuiltin.cpp, HLSL has built on top of existing clang intrinsics via EmitBuiltinExpr. Dot product, though, is a language-specific intrinsic and so is guarded behind getLangOpts().HLSL. The call chain looks like this: EmitBuiltinExpr -> EmitHLSLBuiltinExpr. In EmitHLSLBuiltinExpr, the dot product intrinsic makes a distinction between vectors and scalars. This is because HLSL supports dot product on scalars, which simplifies down to a multiply. Sema.h & SemaChecking.cpp saw the addition of CheckHLSLBuiltinFunctionCall, a language-specific semantic validation that can be expanded for other HLSL-specific intrinsics.
1 parent 5ccf546 commit fad7e6e

File tree

10 files changed

+497
-5
lines changed

10 files changed

+497
-5
lines changed

clang/include/clang/Basic/Builtins.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4524,6 +4524,12 @@ def HLSLCreateHandle : LangBuiltin<"HLSL_LANG"> {
45244524
let Prototype = "void*(unsigned char)";
45254525
}
45264526

4527+
def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> {
4528+
let Spellings = ["__builtin_hlsl_dot"];
4529+
let Attributes = [NoThrow, Const, CustomTypeChecking];
4530+
let Prototype = "void(...)";
4531+
}
4532+
45274533
// Builtins for XRay.
45284534
def XRayCustomEvent : Builtin {
45294535
let Spellings = ["__xray_customevent"];

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10270,6 +10270,8 @@ def err_vec_builtin_non_vector : Error<
1027010270
"first two arguments to %0 must be vectors">;
1027110271
def err_vec_builtin_incompatible_vector : Error<
1027210272
"first two arguments to %0 must have the same type">;
10273+
def err_vec_builtin_incompatible_size : Error<
10274+
"first two arguments to %0 must have the same size">;
1027310275
def err_vsx_builtin_nonconstant_argument : Error<
1027410276
"argument %0 to %1 must be a 2-bit unsigned literal (i.e. 0, 1, 2 or 3)">;
1027510277

clang/include/clang/Sema/Sema.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14057,6 +14057,7 @@ class Sema final {
1405714057
bool CheckPPCBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
1405814058
CallExpr *TheCall);
1405914059
bool CheckAMDGCNBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
14060+
bool CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
1406014061
bool CheckRISCVLMUL(CallExpr *TheCall, unsigned ArgNum);
1406114062
bool CheckRISCVBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
1406214063
CallExpr *TheCall);

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
#include "llvm/IR/IntrinsicsAMDGPU.h"
4545
#include "llvm/IR/IntrinsicsARM.h"
4646
#include "llvm/IR/IntrinsicsBPF.h"
47+
#include "llvm/IR/IntrinsicsDirectX.h"
4748
#include "llvm/IR/IntrinsicsHexagon.h"
4849
#include "llvm/IR/IntrinsicsNVPTX.h"
4950
#include "llvm/IR/IntrinsicsPowerPC.h"
@@ -5982,6 +5983,10 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
59825983
llvm_unreachable("Bad evaluation kind in EmitBuiltinExpr");
59835984
}
59845985

5986+
// EmitHLSLBuiltinExpr will check getLangOpts().HLSL
5987+
if (Value *V = EmitHLSLBuiltinExpr(BuiltinID, E))
5988+
return RValue::get(V);
5989+
59855990
if (getLangOpts().HIPStdPar && getLangOpts().CUDAIsDevice)
59865991
return EmitHipStdParUnsupportedBuiltin(this, FD);
59875992

@@ -17959,6 +17964,50 @@ llvm::Value *CodeGenFunction::EmitScalarOrConstFoldImmArg(unsigned ICEArguments,
1795917964
return Arg;
1796017965
}
1796117966

17967+
Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
17968+
const CallExpr *E) {
17969+
if (!getLangOpts().HLSL)
17970+
return nullptr;
17971+
17972+
switch (BuiltinID) {
17973+
case Builtin::BI__builtin_hlsl_dot: {
17974+
Value *Op0 = EmitScalarExpr(E->getArg(0));
17975+
Value *Op1 = EmitScalarExpr(E->getArg(1));
17976+
llvm::Type *T0 = Op0->getType();
17977+
llvm::Type *T1 = Op1->getType();
17978+
if (!T0->isVectorTy() && !T1->isVectorTy()) {
17979+
if (T0->isFloatingPointTy()) {
17980+
return Builder.CreateFMul(Op0, Op1, "dx.dot");
17981+
}
17982+
17983+
if (T0->isIntegerTy()) {
17984+
return Builder.CreateMul(Op0, Op1, "dx.dot");
17985+
}
17986+
assert(
17987+
false &&
17988+
"Dot product on a scalar is only supported on integers and floats.");
17989+
}
17990+
assert(T0->isVectorTy() && T1->isVectorTy() &&
17991+
"Dot product of vector and scalar is not supported.");
17992+
17993+
// NOTE: this assert will need to be revisited after overload resolution
17994+
// PR merges.
17995+
assert(T0->getScalarType() == T1->getScalarType() &&
17996+
"Dot product of vectors need the same element types.");
17997+
17998+
auto *VecTy0 = E->getArg(0)->getType()->getAs<VectorType>();
17999+
auto *VecTy1 = E->getArg(1)->getType()->getAs<VectorType>();
18000+
assert(VecTy0->getNumElements() == VecTy1->getNumElements() &&
18001+
"Dot product requires vectors to be of the same size.");
18002+
18003+
return Builder.CreateIntrinsic(
18004+
/*ReturnType*/ T0->getScalarType(), Intrinsic::dx_dot,
18005+
ArrayRef<Value *>{Op0, Op1}, nullptr, "dx.dot");
18006+
} break;
18007+
}
18008+
return nullptr;
18009+
}
18010+
1796218011
Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
1796318012
const CallExpr *E) {
1796418013
llvm::AtomicOrdering AO = llvm::AtomicOrdering::SequentiallyConsistent;

clang/lib/CodeGen/CodeGenFunction.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4405,6 +4405,7 @@ class CodeGenFunction : public CodeGenTypeCache {
44054405
llvm::Value *EmitX86BuiltinExpr(unsigned BuiltinID, const CallExpr *E);
44064406
llvm::Value *EmitPPCBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
44074407
llvm::Value *EmitAMDGPUBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
4408+
llvm::Value *EmitHLSLBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
44084409
llvm::Value *EmitScalarOrConstFoldImmArg(unsigned ICEArguments, unsigned Idx,
44094410
const CallExpr *E);
44104411
llvm::Value *EmitSystemZBuiltinExpr(unsigned BuiltinID, const CallExpr *E);

clang/lib/Headers/hlsl/hlsl_intrinsics.h

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,98 @@ double3 cos(double3);
179179
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_cos)
180180
double4 cos(double4);
181181

182+
//===----------------------------------------------------------------------===//
183+
// dot product builtins
184+
//===----------------------------------------------------------------------===//
185+
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
186+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
187+
half dot(half, half);
188+
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
189+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
190+
half dot(half2, half2);
191+
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
192+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
193+
half dot(half3, half3);
194+
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
195+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
196+
half dot(half4, half4);
197+
198+
#ifdef __HLSL_ENABLE_16_BIT
199+
_HLSL_AVAILABILITY(shadermodel, 6.2)
200+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
201+
int16_t dot(int16_t, int16_t);
202+
_HLSL_AVAILABILITY(shadermodel, 6.2)
203+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
204+
int16_t dot(int16_t2, int16_t2);
205+
_HLSL_AVAILABILITY(shadermodel, 6.2)
206+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
207+
int16_t dot(int16_t3, int16_t3);
208+
_HLSL_AVAILABILITY(shadermodel, 6.2)
209+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
210+
int16_t dot(int16_t4, int16_t4);
211+
212+
_HLSL_AVAILABILITY(shadermodel, 6.2)
213+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
214+
uint16_t dot(uint16_t, uint16_t);
215+
_HLSL_AVAILABILITY(shadermodel, 6.2)
216+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
217+
uint16_t dot(uint16_t2, uint16_t2);
218+
_HLSL_AVAILABILITY(shadermodel, 6.2)
219+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
220+
uint16_t dot(uint16_t3, uint16_t3);
221+
_HLSL_AVAILABILITY(shadermodel, 6.2)
222+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
223+
uint16_t dot(uint16_t4, uint16_t4);
224+
#endif
225+
226+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
227+
float dot(float, float);
228+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
229+
float dot(float2, float2);
230+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
231+
float dot(float3, float3);
232+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
233+
float dot(float4, float4);
234+
235+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
236+
double dot(double, double);
237+
238+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
239+
int dot(int, int);
240+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
241+
int dot(int2, int2);
242+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
243+
int dot(int3, int3);
244+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
245+
int dot(int4, int4);
246+
247+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
248+
uint dot(uint, uint);
249+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
250+
uint dot(uint2, uint2);
251+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
252+
uint dot(uint3, uint3);
253+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
254+
uint dot(uint4, uint4);
255+
256+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
257+
int64_t dot(int64_t, int64_t);
258+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
259+
int64_t dot(int64_t2, int64_t2);
260+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
261+
int64_t dot(int64_t3, int64_t3);
262+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
263+
int64_t dot(int64_t4, int64_t4);
264+
265+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
266+
uint64_t dot(uint64_t, uint64_t);
267+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
268+
uint64_t dot(uint64_t2, uint64_t2);
269+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
270+
uint64_t dot(uint64_t3, uint64_t3);
271+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
272+
uint64_t dot(uint64_t4, uint64_t4);
273+
182274
//===----------------------------------------------------------------------===//
183275
// floor builtins
184276
//===----------------------------------------------------------------------===//

clang/lib/Sema/SemaChecking.cpp

Lines changed: 79 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2120,10 +2120,11 @@ bool Sema::CheckTSBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
21202120
// not a valid type, emit an error message and return true. Otherwise return
21212121
// false.
21222122
static bool checkMathBuiltinElementType(Sema &S, SourceLocation Loc,
2123-
QualType Ty) {
2124-
if (!Ty->getAs<VectorType>() && !ConstantMatrixType::isValidElementType(Ty)) {
2123+
QualType ArgTy, int ArgIndex) {
2124+
if (!ArgTy->getAs<VectorType>() &&
2125+
!ConstantMatrixType::isValidElementType(ArgTy)) {
21252126
return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
2126-
<< 1 << /* vector, integer or float ty*/ 0 << Ty;
2127+
<< ArgIndex << /* vector, integer or float ty*/ 0 << ArgTy;
21272128
}
21282129

21292130
return false;
@@ -2961,6 +2962,10 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
29612962
}
29622963
}
29632964

2965+
if (getLangOpts().HLSL && CheckHLSLBuiltinFunctionCall(BuiltinID, TheCall)) {
2966+
return ExprError();
2967+
}
2968+
29642969
// Since the target specific builtins for each arch overlap, only check those
29652970
// of the arch we are compiling for.
29662971
if (Context.BuiltinInfo.isTSBuiltin(BuiltinID)) {
@@ -5161,6 +5166,75 @@ bool Sema::CheckPPCMMAType(QualType Type, SourceLocation TypeLoc) {
51615166
return false;
51625167
}
51635168

5169+
// Note: returning true in this case results in CheckBuiltinFunctionCall
5170+
// returning an ExprError
5171+
bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
5172+
switch (BuiltinID) {
5173+
case Builtin::BI__builtin_hlsl_dot: {
5174+
if (checkArgCount(*this, TheCall, 2)) {
5175+
return true;
5176+
}
5177+
Expr *Arg0 = TheCall->getArg(0);
5178+
QualType ArgTy0 = Arg0->getType();
5179+
5180+
Expr *Arg1 = TheCall->getArg(1);
5181+
QualType ArgTy1 = Arg1->getType();
5182+
5183+
auto *VecTy0 = ArgTy0->getAs<VectorType>();
5184+
auto *VecTy1 = ArgTy1->getAs<VectorType>();
5185+
SourceLocation BuiltinLoc = TheCall->getBeginLoc();
5186+
5187+
// if arg0 is bool then call Diag with err_builtin_invalid_arg_type
5188+
if (checkMathBuiltinElementType(*this, Arg0->getBeginLoc(), ArgTy0, 1)) {
5189+
return true;
5190+
}
5191+
5192+
// if arg1 is bool then call Diag with err_builtin_invalid_arg_type
5193+
if (checkMathBuiltinElementType(*this, Arg1->getBeginLoc(), ArgTy1, 2)) {
5194+
return true;
5195+
}
5196+
5197+
if (VecTy0 == nullptr && VecTy1 == nullptr) {
5198+
if (ArgTy0 != ArgTy1) {
5199+
return true;
5200+
} else {
5201+
return false;
5202+
}
5203+
}
5204+
5205+
if ((VecTy0 == nullptr && VecTy1 != nullptr) ||
5206+
(VecTy0 != nullptr && VecTy1 == nullptr)) {
5207+
5208+
Diag(BuiltinLoc, diag::err_vec_builtin_non_vector)
5209+
<< TheCall->getDirectCallee()
5210+
<< SourceRange(TheCall->getArg(0)->getBeginLoc(),
5211+
TheCall->getArg(1)->getEndLoc());
5212+
return true;
5213+
}
5214+
5215+
if (VecTy0->getElementType() != VecTy1->getElementType()) {
5216+
// Note: This case should never happen. If type promotion occurs
5217+
// then element types won't be different. This diag error is here
5218+
// because EmitHLSLBuiltinExpr asserts on this case.
5219+
Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_vector)
5220+
<< TheCall->getDirectCallee()
5221+
<< SourceRange(TheCall->getArg(0)->getBeginLoc(),
5222+
TheCall->getArg(1)->getEndLoc());
5223+
return true;
5224+
}
5225+
if (VecTy0->getNumElements() != VecTy1->getNumElements()) {
5226+
Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_size)
5227+
<< TheCall->getDirectCallee()
5228+
<< SourceRange(TheCall->getArg(0)->getBeginLoc(),
5229+
TheCall->getArg(1)->getEndLoc());
5230+
return true;
5231+
}
5232+
break;
5233+
}
5234+
}
5235+
return false;
5236+
}
5237+
51645238
bool Sema::CheckAMDGCNBuiltinFunctionCall(unsigned BuiltinID,
51655239
CallExpr *TheCall) {
51665240
// position of memory order and scope arguments in the builtin
@@ -19594,7 +19668,7 @@ bool Sema::PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall) {
1959419668
TheCall->setArg(0, A.get());
1959519669
QualType TyA = A.get()->getType();
1959619670

19597-
if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA))
19671+
if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA, 1))
1959819672
return true;
1959919673

1960019674
TheCall->setType(TyA);
@@ -19622,7 +19696,7 @@ bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) {
1962219696
diag::err_typecheck_call_different_arg_types)
1962319697
<< TyA << TyB;
1962419698

19625-
if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA))
19699+
if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA, 1))
1962619700
return true;
1962719701

1962819702
TheCall->setArg(0, A.get());

0 commit comments

Comments
 (0)