Skip to content

Commit 73cc9fd

Browse files
committed
[HLSL] Implementation of dot intrinsic
This change implements #70073 HLSL has a dot intrinsic defined here: https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-dot The intrinsic itself is defined as a HLSL_LANG LangBuiltin in BuiltinsHLSL.td. This is used to associate all the dot product typdef defined hlsl_intrinsics.h with a single intrinsic check in CGBuiltin.cpp & SemaChecking.cpp. As a side note adding the dot product intrinsic to BuiltinsHLSL.td had a significant impact on re-compile time speeds. I recommend we move the other hlsl functions here. Further it lets us tap into the existing target specifc code organizations that exist in Sema and CodeGen. In IntrinsicsDirectX.td we define the llvmIR for the dot product. A few goals were in mind for this IR. First it should operate on only vectors. Second the return type should be the vector element type. Third the second parameter vector should be of the same size as the first parameter. Finally `a dot b` should be the same as `b dot a`. In CGBuiltin.cpp hlsl has built on top of existing clang intrinsics via EmitBuiltinExpr. Dot product though is a target specific intrinsic and so needed to establish a pattern for Target builtins via EmitDXILBuiltinExpr. The call chain looks like this now: EmitBuiltinExpr -> EmitTargetBuiltinExpr -> EmitTargetArchBuiltinExpr -> EmitDXILBuiltinExp EmitDXILBuiltinExp dot product intrinsics makes a destinction between vectors and scalars. This is because HLSL supports dot product on scalars which simplifies down to multiply. Sema.h & SemaChecking.cpp saw the addition of CheckHLSLBuiltinFunctionCall, a target specific semantic validation that can be expanded for other hlsl specific intrinsics.
1 parent 78f2eb8 commit 73cc9fd

File tree

13 files changed

+471
-9
lines changed

13 files changed

+471
-9
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
//===--- BuiltinsHLSL.td - HLSL Builtin function database ---------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
include "clang/Basic/BuiltinsBase.td"
10+
11+
def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> {
12+
let Spellings = ["__builtin_hlsl_dot"];
13+
let Attributes = [NoThrow, Const, CustomTypeChecking];
14+
let Prototype = "void(...)";
15+
}

clang/include/clang/Basic/CMakeLists.txt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,11 @@ clang_tablegen(BuiltinsBPF.inc -gen-clang-builtins
6565
SOURCE BuiltinsBPF.td
6666
TARGET ClangBuiltinsBPF)
6767

68-
clang_tablegen(BuiltinsRISCV.inc -gen-clang-builtins
68+
clang_tablegen(BuiltinsHLSL.inc -gen-clang-builtins
69+
SOURCE BuiltinsHLSL.td
70+
TARGET ClangBuiltinsHLSL)
71+
72+
clang_tablegen(BuiltinsRISCV.inc -gen-clang-builtins
6973
SOURCE BuiltinsRISCV.td
7074
TARGET ClangBuiltinsRISCV)
7175

clang/include/clang/Basic/TargetBuiltins.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,16 @@ namespace clang {
8989
};
9090
}
9191

92+
/// HLSL builtins
93+
namespace hlsl {
94+
enum {
95+
LastTIBuiltin = clang::Builtin::FirstTSBuiltin - 1,
96+
#define BUILTIN(ID, TYPE, ATTRS) BI##ID,
97+
#include "clang/Basic/BuiltinsHLSL.inc"
98+
LastTSBuiltin
99+
};
100+
} // namespace hlsl
101+
92102
/// PPC builtins
93103
namespace PPC {
94104
enum {

clang/include/clang/Sema/Sema.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14016,6 +14016,7 @@ class Sema final {
1401614016
bool CheckPPCBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
1401714017
CallExpr *TheCall);
1401814018
bool CheckAMDGCNBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
14019+
bool CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
1401914020
bool CheckRISCVLMUL(CallExpr *TheCall, unsigned ArgNum);
1402014021
bool CheckRISCVBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
1402114022
CallExpr *TheCall);

clang/lib/Basic/Targets/DirectX.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,24 @@
1212

1313
#include "DirectX.h"
1414
#include "Targets.h"
15+
#include "clang/Basic/MacroBuilder.h"
16+
#include "clang/Basic/TargetBuiltins.h"
1517

1618
using namespace clang;
1719
using namespace clang::targets;
1820

21+
static constexpr Builtin::Info BuiltinInfo[] = {
22+
#define BUILTIN(ID, TYPE, ATTRS) \
23+
{#ID, TYPE, ATTRS, nullptr, HeaderDesc::NO_HEADER, ALL_LANGUAGES},
24+
#include "clang/Basic/BuiltinsHLSL.inc"
25+
};
26+
1927
void DirectXTargetInfo::getTargetDefines(const LangOptions &Opts,
2028
MacroBuilder &Builder) const {
2129
DefineStd(Builder, "DIRECTX", Opts);
2230
}
31+
32+
ArrayRef<Builtin::Info> DirectXTargetInfo::getTargetBuiltins() const {
33+
return llvm::ArrayRef(BuiltinInfo,
34+
clang::hlsl::LastTSBuiltin - Builtin::FirstTSBuiltin);
35+
}

clang/lib/Basic/Targets/DirectX.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,7 @@ class LLVM_LIBRARY_VISIBILITY DirectXTargetInfo : public TargetInfo {
7373
return Feature == "directx";
7474
}
7575

76-
ArrayRef<Builtin::Info> getTargetBuiltins() const override {
77-
return std::nullopt;
78-
}
76+
ArrayRef<Builtin::Info> getTargetBuiltins() const override;
7977

8078
std::string_view getClobbers() const override { return ""; }
8179

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
#include "llvm/IR/IntrinsicsAMDGPU.h"
4545
#include "llvm/IR/IntrinsicsARM.h"
4646
#include "llvm/IR/IntrinsicsBPF.h"
47+
#include "llvm/IR/IntrinsicsDirectX.h"
4748
#include "llvm/IR/IntrinsicsHexagon.h"
4849
#include "llvm/IR/IntrinsicsNVPTX.h"
4950
#include "llvm/IR/IntrinsicsPowerPC.h"
@@ -6018,6 +6019,8 @@ static Value *EmitTargetArchBuiltinExpr(CodeGenFunction *CGF,
60186019
case llvm::Triple::bpfeb:
60196020
case llvm::Triple::bpfel:
60206021
return CGF->EmitBPFBuiltinExpr(BuiltinID, E);
6022+
case llvm::Triple::dxil:
6023+
return CGF->EmitDXILBuiltinExpr(BuiltinID, E);
60216024
case llvm::Triple::x86:
60226025
case llvm::Triple::x86_64:
60236026
return CGF->EmitX86BuiltinExpr(BuiltinID, E);
@@ -17895,6 +17898,44 @@ llvm::Value *CodeGenFunction::EmitScalarOrConstFoldImmArg(unsigned ICEArguments,
1789517898
return Arg;
1789617899
}
1789717900

17901+
Value *CodeGenFunction::EmitDXILBuiltinExpr(unsigned BuiltinID,
17902+
const CallExpr *E) {
17903+
switch (BuiltinID) {
17904+
case hlsl::BI__builtin_hlsl_dot: {
17905+
Value *Op0 = EmitScalarExpr(E->getArg(0));
17906+
Value *Op1 = EmitScalarExpr(E->getArg(1));
17907+
llvm::Type *T0 = Op0->getType();
17908+
llvm::Type *T1 = Op1->getType();
17909+
if (!T0->isVectorTy() && !T1->isVectorTy()) {
17910+
if (T0->isFloatingPointTy()) {
17911+
return Builder.CreateFMul(Op0, Op1, "dx.dot");
17912+
}
17913+
17914+
if (T0->isIntegerTy()) {
17915+
return Builder.CreateMul(Op0, Op1, "dx.dot");
17916+
}
17917+
assert(
17918+
false &&
17919+
"Dot product on a scalar is only supported on integers and floats.");
17920+
}
17921+
assert(T0->isVectorTy() && T1->isVectorTy() &&
17922+
"Dot product of vector and scalar is not supported.");
17923+
assert(T0->getScalarType() == T1->getScalarType() &&
17924+
"Dot product of vectors need the same element types.");
17925+
17926+
auto *VecTy0 = E->getArg(0)->getType()->getAs<VectorType>();
17927+
auto *VecTy1 = E->getArg(1)->getType()->getAs<VectorType>();
17928+
assert(VecTy0->getNumElements() == VecTy1->getNumElements() &&
17929+
"Dot product requires vectors to be of the same size.");
17930+
17931+
return Builder.CreateIntrinsic(
17932+
/*ReturnType*/ T0->getScalarType(), Intrinsic::dx_dot,
17933+
ArrayRef<Value *>{Op0, Op1}, nullptr, "dx.dot");
17934+
} break;
17935+
}
17936+
return nullptr;
17937+
}
17938+
1789817939
Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
1789917940
const CallExpr *E) {
1790017941
llvm::AtomicOrdering AO = llvm::AtomicOrdering::SequentiallyConsistent;

clang/lib/CodeGen/CodeGenFunction.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4395,6 +4395,7 @@ class CodeGenFunction : public CodeGenTypeCache {
43954395
llvm::Value *EmitX86BuiltinExpr(unsigned BuiltinID, const CallExpr *E);
43964396
llvm::Value *EmitPPCBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
43974397
llvm::Value *EmitAMDGPUBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
4398+
llvm::Value *EmitDXILBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
43984399
llvm::Value *EmitScalarOrConstFoldImmArg(unsigned ICEArguments, unsigned Idx,
43994400
const CallExpr *E);
44004401
llvm::Value *EmitSystemZBuiltinExpr(unsigned BuiltinID, const CallExpr *E);

clang/lib/Headers/hlsl/hlsl_intrinsics.h

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,74 @@ double3 cos(double3);
144144
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_cos)
145145
double4 cos(double4);
146146

147+
//===----------------------------------------------------------------------===//
148+
// dot product builtins
149+
//===----------------------------------------------------------------------===//
150+
#ifdef __HLSL_ENABLE_16_BIT
151+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
152+
half dot(half, half);
153+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
154+
half dot(half2, half2);
155+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
156+
half dot(half3, half3);
157+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
158+
half dot(half4, half4);
159+
160+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
161+
int16_t dot(int16_t, int16_t);
162+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
163+
int16_t dot(int16_t2, int16_t2);
164+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
165+
int16_t dot(int16_t3, int16_t3);
166+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
167+
int16_t dot(int16_t4, int16_t4);
168+
169+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
170+
uint16_t dot(uint16_t, uint16_t);
171+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
172+
uint16_t dot(uint16_t2, uint16_t2);
173+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
174+
uint16_t dot(uint16_t3, uint16_t3);
175+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
176+
uint16_t dot(uint16_t4, uint16_t4);
177+
#endif
178+
179+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
180+
float dot(float, float);
181+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
182+
float dot(float2, float2);
183+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
184+
float dot(float3, float3);
185+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
186+
float dot(float4, float4);
187+
188+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
189+
double dot(double, double);
190+
191+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
192+
int dot(int, int);
193+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
194+
int dot(int2, int2);
195+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
196+
int dot(int3, int3);
197+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
198+
int dot(int4, int4);
199+
200+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
201+
uint dot(uint, uint);
202+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
203+
uint dot(uint2, uint2);
204+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
205+
uint dot(uint3, uint3);
206+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
207+
uint dot(uint4, uint4);
208+
209+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
210+
int64_t dot(int64_t, int64_t);
211+
212+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
213+
uint64_t dot(uint64_t, uint64_t);
214+
147215
//===----------------------------------------------------------------------===//
148216
// floor builtins
149217
//===----------------------------------------------------------------------===//

clang/lib/Sema/SemaChecking.cpp

Lines changed: 66 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2084,6 +2084,8 @@ bool Sema::CheckTSBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
20842084
return CheckBPFBuiltinFunctionCall(BuiltinID, TheCall);
20852085
case llvm::Triple::hexagon:
20862086
return CheckHexagonBuiltinFunctionCall(BuiltinID, TheCall);
2087+
case llvm::Triple::dxil:
2088+
return CheckHLSLBuiltinFunctionCall(BuiltinID, TheCall);
20872089
case llvm::Triple::mips:
20882090
case llvm::Triple::mipsel:
20892091
case llvm::Triple::mips64:
@@ -2120,10 +2122,11 @@ bool Sema::CheckTSBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
21202122
// not a valid type, emit an error message and return true. Otherwise return
21212123
// false.
21222124
static bool checkMathBuiltinElementType(Sema &S, SourceLocation Loc,
2123-
QualType Ty) {
2124-
if (!Ty->getAs<VectorType>() && !ConstantMatrixType::isValidElementType(Ty)) {
2125+
QualType ArgTy, int ArgIndex) {
2126+
if (!ArgTy->getAs<VectorType>() &&
2127+
!ConstantMatrixType::isValidElementType(ArgTy)) {
21252128
return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
2126-
<< 1 << /* vector, integer or float ty*/ 0 << Ty;
2129+
<< ArgIndex << /* vector, integer or float ty*/ 0 << ArgTy;
21272130
}
21282131

21292132
return false;
@@ -5158,6 +5161,64 @@ bool Sema::CheckPPCMMAType(QualType Type, SourceLocation TypeLoc) {
51585161
return false;
51595162
}
51605163

5164+
// Note: returning true in this case results in CheckBuiltinFunctionCall
5165+
// returning an ExprError
5166+
bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
5167+
switch (BuiltinID) {
5168+
case hlsl::BI__builtin_hlsl_dot: {
5169+
if (checkArgCount(*this, TheCall, 2)) {
5170+
return true;
5171+
}
5172+
Expr *Arg0 = TheCall->getArg(0);
5173+
QualType ArgTy0 = Arg0->getType();
5174+
5175+
Expr *Arg1 = TheCall->getArg(1);
5176+
QualType ArgTy1 = Arg1->getType();
5177+
5178+
auto *VecTy0 = ArgTy0->getAs<VectorType>();
5179+
auto *VecTy1 = ArgTy1->getAs<VectorType>();
5180+
SourceLocation BuiltinLoc = TheCall->getBeginLoc();
5181+
5182+
// if arg0 is bool then call Diag with err_builtin_invalid_arg_type
5183+
if (checkMathBuiltinElementType(*this, Arg0->getBeginLoc(), ArgTy0, 1)) {
5184+
return true;
5185+
}
5186+
5187+
// if arg1 is bool then call Diag with err_builtin_invalid_arg_type
5188+
if (checkMathBuiltinElementType(*this, Arg1->getBeginLoc(), ArgTy1, 2)) {
5189+
return true;
5190+
}
5191+
5192+
if (VecTy0 == nullptr && VecTy1 == nullptr) {
5193+
if (ArgTy0 != ArgTy1) {
5194+
return true;
5195+
} else {
5196+
return false;
5197+
}
5198+
}
5199+
5200+
if ((VecTy0 == nullptr && VecTy1 != nullptr) ||
5201+
(VecTy0 != nullptr && VecTy1 == nullptr)) {
5202+
5203+
Diag(BuiltinLoc, diag::err_vec_builtin_non_vector)
5204+
<< TheCall->getDirectCallee()
5205+
<< SourceRange(TheCall->getArg(0)->getBeginLoc(),
5206+
TheCall->getArg(1)->getEndLoc());
5207+
return true;
5208+
}
5209+
5210+
if (VecTy0->getElementType() != VecTy1->getElementType()) {
5211+
return true;
5212+
}
5213+
if (VecTy0->getNumElements() != VecTy1->getNumElements()) {
5214+
return true;
5215+
}
5216+
break;
5217+
}
5218+
}
5219+
return false;
5220+
}
5221+
51615222
bool Sema::CheckAMDGCNBuiltinFunctionCall(unsigned BuiltinID,
51625223
CallExpr *TheCall) {
51635224
// position of memory order and scope arguments in the builtin
@@ -19576,7 +19637,7 @@ bool Sema::PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall) {
1957619637
TheCall->setArg(0, A.get());
1957719638
QualType TyA = A.get()->getType();
1957819639

19579-
if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA))
19640+
if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA, 1))
1958019641
return true;
1958119642

1958219643
TheCall->setType(TyA);
@@ -19604,7 +19665,7 @@ bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) {
1960419665
diag::err_typecheck_call_different_arg_types)
1960519666
<< TyA << TyB;
1960619667

19607-
if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA))
19668+
if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA, 1))
1960819669
return true;
1960919670

1961019671
TheCall->setArg(0, A.get());

0 commit comments

Comments
 (0)