Skip to content

Commit 4103940

Browse files
committed
[HLSL] Implementation of dot intrinsic
This change implements llvm#70073. HLSL has a dot intrinsic defined here: https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-dot The intrinsic itself is defined as an HLSL_LANG LangBuiltin in BuiltinsHLSL.td. This is used to associate all the dot product typedefs defined in hlsl_intrinsics.h with a single intrinsic check in CGBuiltin.cpp & SemaChecking.cpp. As a side note, adding the dot product intrinsic to BuiltinsHLSL.td had a significant impact on re-compile time speeds. I recommend we move the other hlsl functions here. Further, it lets us tap into the existing target-specific code organizations that exist in Sema and CodeGen. In IntrinsicsDirectX.td we define the LLVM IR for the dot product. A few goals were in mind for this IR. First, it should operate only on vectors. Second, the return type should be the vector element type. Third, the second parameter vector should be of the same size as the first parameter. Finally, `a dot b` should be the same as `b dot a`. In CGBuiltin.cpp hlsl has built on top of existing clang intrinsics via EmitBuiltinExpr. Dot product, though, is a target-specific intrinsic and so needed to establish a pattern for target builtins via EmitDXILBuiltinExpr. The call chain looks like this now: EmitBuiltinExpr -> EmitTargetBuiltinExpr -> EmitTargetArchBuiltinExpr -> EmitDXILBuiltinExpr. For the dot product intrinsic, EmitDXILBuiltinExpr makes a distinction between vectors and scalars. This is because HLSL supports dot product on scalars, which simplifies down to a multiply. Sema.h & SemaChecking.cpp saw the addition of CheckHLSLBuiltinFunctionCall, a target-specific semantic validation that can be expanded for other hlsl-specific intrinsics.
1 parent 067d277 commit 4103940

File tree

13 files changed

+470
-8
lines changed

13 files changed

+470
-8
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
//===--- BuiltinsHLSL.td - HLSL Builtin function database ---------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
include "clang/Basic/BuiltinsBase.td"
10+
11+
// The one builtin (`__builtin_hlsl_dot`) that the `dot` overloads declared in
// hlsl_intrinsics.h alias. It is registered as an HLSL_LANG language builtin
// with a variadic `void(...)` prototype and the CustomTypeChecking attribute;
// its arguments are validated in Sema::CheckHLSLBuiltinFunctionCall.
def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> {
12+
let Spellings = ["__builtin_hlsl_dot"];
13+
let Attributes = [NoThrow, Const, CustomTypeChecking];
14+
let Prototype = "void(...)";
15+
}

clang/include/clang/Basic/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,10 @@ clang_tablegen(BuiltinsBPF.inc -gen-clang-builtins
6565
SOURCE BuiltinsBPF.td
6666
TARGET ClangBuiltinsBPF)
6767

68+
clang_tablegen(BuiltinsHLSL.inc -gen-clang-builtins
69+
SOURCE BuiltinsHLSL.td
70+
TARGET ClangBuiltinsHLSL)
71+
6872
# ARM NEON and MVE
6973
clang_tablegen(arm_neon.inc -gen-arm-neon-sema
7074
SOURCE arm_neon.td

clang/include/clang/Basic/TargetBuiltins.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,16 @@ namespace clang {
8989
};
9090
}
9191

92+
/// HLSL builtins
93+
namespace hlsl {
94+
enum {
95+
// Sentinel placed one below the first target-specific builtin ID, so the
// first enumerator generated by BUILTIN() below gets the value
// clang::Builtin::FirstTSBuiltin.
LastTIBuiltin = clang::Builtin::FirstTSBuiltin - 1,
96+
// Each BUILTIN entry in the generated .inc file becomes one BI<ID> enumerator.
#define BUILTIN(ID, TYPE, ATTRS) BI##ID,
97+
#include "clang/Basic/BuiltinsHLSL.inc"
98+
// One past the last HLSL builtin ID; LastTSBuiltin - FirstTSBuiltin is the
// number of HLSL builtins.
LastTSBuiltin
99+
};
100+
} // namespace hlsl
101+
92102
/// PPC builtins
93103
namespace PPC {
94104
enum {

clang/include/clang/Sema/Sema.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14017,6 +14017,7 @@ class Sema final {
1401714017
bool CheckPPCBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
1401814018
CallExpr *TheCall);
1401914019
bool CheckAMDGCNBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
14020+
bool CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
1402014021
bool CheckRISCVLMUL(CallExpr *TheCall, unsigned ArgNum);
1402114022
bool CheckRISCVBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
1402214023
CallExpr *TheCall);

clang/lib/Basic/Targets/DirectX.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,24 @@
1212

1313
#include "DirectX.h"
1414
#include "Targets.h"
15+
#include "clang/Basic/MacroBuilder.h"
16+
#include "clang/Basic/TargetBuiltins.h"
1517

1618
using namespace clang;
1719
using namespace clang::targets;
1820

21+
// Table of HLSL builtin descriptors, one entry per BUILTIN() line in the
// generated BuiltinsHLSL.inc. Every entry is registered with no header
// requirement (HeaderDesc::NO_HEADER) and ALL_LANGUAGES.
static constexpr Builtin::Info BuiltinInfo[] = {
22+
#define BUILTIN(ID, TYPE, ATTRS) \
23+
{#ID, TYPE, ATTRS, nullptr, HeaderDesc::NO_HEADER, ALL_LANGUAGES},
24+
#include "clang/Basic/BuiltinsHLSL.inc"
25+
};
26+
1927
void DirectXTargetInfo::getTargetDefines(const LangOptions &Opts,
2028
MacroBuilder &Builder) const {
2129
DefineStd(Builder, "DIRECTX", Opts);
2230
}
31+
32+
// Returns a view over the BuiltinInfo table. The element count is computed
// from the builtin ID enum as hlsl::LastTSBuiltin - Builtin::FirstTSBuiltin.
ArrayRef<Builtin::Info> DirectXTargetInfo::getTargetBuiltins() const {
33+
return llvm::ArrayRef(BuiltinInfo,
34+
clang::hlsl::LastTSBuiltin - Builtin::FirstTSBuiltin);
35+
}

clang/lib/Basic/Targets/DirectX.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,7 @@ class LLVM_LIBRARY_VISIBILITY DirectXTargetInfo : public TargetInfo {
7373
return Feature == "directx";
7474
}
7575

76-
ArrayRef<Builtin::Info> getTargetBuiltins() const override {
77-
return std::nullopt;
78-
}
76+
ArrayRef<Builtin::Info> getTargetBuiltins() const override;
7977

8078
std::string_view getClobbers() const override { return ""; }
8179

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
#include "llvm/IR/IntrinsicsAMDGPU.h"
4545
#include "llvm/IR/IntrinsicsARM.h"
4646
#include "llvm/IR/IntrinsicsBPF.h"
47+
#include "llvm/IR/IntrinsicsDirectX.h"
4748
#include "llvm/IR/IntrinsicsHexagon.h"
4849
#include "llvm/IR/IntrinsicsNVPTX.h"
4950
#include "llvm/IR/IntrinsicsPowerPC.h"
@@ -6018,6 +6019,8 @@ static Value *EmitTargetArchBuiltinExpr(CodeGenFunction *CGF,
60186019
case llvm::Triple::bpfeb:
60196020
case llvm::Triple::bpfel:
60206021
return CGF->EmitBPFBuiltinExpr(BuiltinID, E);
6022+
case llvm::Triple::dxil:
6023+
return CGF->EmitDXILBuiltinExpr(BuiltinID, E);
60216024
case llvm::Triple::x86:
60226025
case llvm::Triple::x86_64:
60236026
return CGF->EmitX86BuiltinExpr(BuiltinID, E);
@@ -17895,6 +17898,44 @@ llvm::Value *CodeGenFunction::EmitScalarOrConstFoldImmArg(unsigned ICEArguments,
1789517898
return Arg;
1789617899
}
1789717900

17901+
/// Emit IR for a call to a DirectX (DXIL) target builtin.
///
/// \param BuiltinID the builtin being called. Only
///        hlsl::BI__builtin_hlsl_dot is handled here.
/// \param E the call expression; both arguments are emitted with
///        EmitScalarExpr.
/// \returns the emitted value, or nullptr when \p BuiltinID is not handled by
///          this function.
Value *CodeGenFunction::EmitDXILBuiltinExpr(unsigned BuiltinID,
                                            const CallExpr *E) {
  switch (BuiltinID) {
  case hlsl::BI__builtin_hlsl_dot: {
    Value *Op0 = EmitScalarExpr(E->getArg(0));
    Value *Op1 = EmitScalarExpr(E->getArg(1));
    llvm::Type *T0 = Op0->getType();
    llvm::Type *T1 = Op1->getType();

    // A dot product of two scalars is emitted as a plain multiply rather than
    // a call to the dx_dot intrinsic.
    if (!T0->isVectorTy() && !T1->isVectorTy()) {
      if (T0->isFloatingPointTy())
        return Builder.CreateFMul(Op0, Op1, "dx.dot");

      // Every path through the scalar branch returns. The previous
      // `assert(false && ...)` vanished when asserts were compiled out, which
      // let control fall through to the vector path below and emit dx_dot
      // with scalar operands.
      assert(
          T0->isIntegerTy() &&
          "Dot product on a scalar is only supported on integers and floats.");
      return Builder.CreateMul(Op0, Op1, "dx.dot");
    }

    assert(T0->isVectorTy() && T1->isVectorTy() &&
           "Dot product of vector and scalar is not supported.");
    assert(T0->getScalarType() == T1->getScalarType() &&
           "Dot product of vectors need the same element types.");

    // These are only read by the assert below; [[maybe_unused]] keeps builds
    // with asserts disabled free of unused-variable warnings.
    [[maybe_unused]] auto *VecTy0 =
        E->getArg(0)->getType()->getAs<VectorType>();
    [[maybe_unused]] auto *VecTy1 =
        E->getArg(1)->getType()->getAs<VectorType>();
    assert(VecTy0->getNumElements() == VecTy1->getNumElements() &&
           "Dot product requires vectors to be of the same size.");

    // The intrinsic's result type is the vector element type.
    return Builder.CreateIntrinsic(
        /*ReturnType*/ T0->getScalarType(), Intrinsic::dx_dot,
        ArrayRef<Value *>{Op0, Op1}, nullptr, "dx.dot");
  }
  }
  // Not a builtin handled by this function.
  return nullptr;
}
17938+
1789817939
Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
1789917940
const CallExpr *E) {
1790017941
llvm::AtomicOrdering AO = llvm::AtomicOrdering::SequentiallyConsistent;

clang/lib/CodeGen/CodeGenFunction.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4392,6 +4392,7 @@ class CodeGenFunction : public CodeGenTypeCache {
43924392
llvm::Value *EmitX86BuiltinExpr(unsigned BuiltinID, const CallExpr *E);
43934393
llvm::Value *EmitPPCBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
43944394
llvm::Value *EmitAMDGPUBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
4395+
llvm::Value *EmitDXILBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
43954396
llvm::Value *EmitScalarOrConstFoldImmArg(unsigned ICEArguments, unsigned Idx,
43964397
const CallExpr *E);
43974398
llvm::Value *EmitSystemZBuiltinExpr(unsigned BuiltinID, const CallExpr *E);

clang/lib/Headers/hlsl/hlsl_intrinsics.h

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,74 @@ double3 cos(double3);
144144
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_cos)
145145
double4 cos(double4);
146146

147+
//===----------------------------------------------------------------------===//
148+
// dot product builtins
149+
//===----------------------------------------------------------------------===//
150+
#ifdef __HLSL_ENABLE_16_BIT
151+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
152+
half dot(half, half);
153+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
154+
half dot(half2, half2);
155+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
156+
half dot(half3, half3);
157+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
158+
half dot(half4, half4);
159+
160+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
161+
int16_t dot(int16_t, int16_t);
162+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
163+
int16_t dot(int16_t2, int16_t2);
164+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
165+
int16_t dot(int16_t3, int16_t3);
166+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
167+
int16_t dot(int16_t4, int16_t4);
168+
169+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
170+
uint16_t dot(uint16_t, uint16_t);
171+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
172+
uint16_t dot(uint16_t2, uint16_t2);
173+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
174+
uint16_t dot(uint16_t3, uint16_t3);
175+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
176+
uint16_t dot(uint16_t4, uint16_t4);
177+
#endif
178+
179+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
180+
float dot(float, float);
181+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
182+
float dot(float2, float2);
183+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
184+
float dot(float3, float3);
185+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
186+
float dot(float4, float4);
187+
188+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
189+
double dot(double, double);
190+
191+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
192+
int dot(int, int);
193+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
194+
int dot(int2, int2);
195+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
196+
int dot(int3, int3);
197+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
198+
int dot(int4, int4);
199+
200+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
201+
uint dot(uint, uint);
202+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
203+
uint dot(uint2, uint2);
204+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
205+
uint dot(uint3, uint3);
206+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
207+
uint dot(uint4, uint4);
208+
209+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
210+
int64_t dot(int64_t, int64_t);
211+
212+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
213+
uint64_t dot(uint64_t, uint64_t);
214+
147215
//===----------------------------------------------------------------------===//
148216
// floor builtins
149217
//===----------------------------------------------------------------------===//

clang/lib/Sema/SemaChecking.cpp

Lines changed: 66 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2084,6 +2084,8 @@ bool Sema::CheckTSBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
20842084
return CheckBPFBuiltinFunctionCall(BuiltinID, TheCall);
20852085
case llvm::Triple::hexagon:
20862086
return CheckHexagonBuiltinFunctionCall(BuiltinID, TheCall);
2087+
case llvm::Triple::dxil:
2088+
return CheckHLSLBuiltinFunctionCall(BuiltinID, TheCall);
20872089
case llvm::Triple::mips:
20882090
case llvm::Triple::mipsel:
20892091
case llvm::Triple::mips64:
@@ -2120,10 +2122,11 @@ bool Sema::CheckTSBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
21202122
// not a valid type, emit an error message and return true. Otherwise return
21212123
// false.
21222124
static bool checkMathBuiltinElementType(Sema &S, SourceLocation Loc,
2123-
QualType Ty) {
2124-
if (!Ty->getAs<VectorType>() && !ConstantMatrixType::isValidElementType(Ty)) {
2125+
QualType ArgTy, int ArgIndex) {
2126+
if (!ArgTy->getAs<VectorType>() &&
2127+
!ConstantMatrixType::isValidElementType(ArgTy)) {
21252128
return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
2126-
<< 1 << /* vector, integer or float ty*/ 0 << Ty;
2129+
<< ArgIndex << /* vector, integer or float ty*/ 0 << ArgTy;
21272130
}
21282131

21292132
return false;
@@ -5158,6 +5161,64 @@ bool Sema::CheckPPCMMAType(QualType Type, SourceLocation TypeLoc) {
51585161
return false;
51595162
}
51605163

5164+
// Note: returning true in this case results in CheckBuiltinFunctionCall
5165+
// returning an ExprError
5166+
bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
5167+
switch (BuiltinID) {
5168+
case hlsl::BI__builtin_hlsl_dot: {
5169+
if (checkArgCount(*this, TheCall, 2)) {
5170+
return true;
5171+
}
5172+
Expr *Arg0 = TheCall->getArg(0);
5173+
QualType ArgTy0 = Arg0->getType();
5174+
5175+
Expr *Arg1 = TheCall->getArg(1);
5176+
QualType ArgTy1 = Arg1->getType();
5177+
5178+
auto *VecTy0 = ArgTy0->getAs<VectorType>();
5179+
auto *VecTy1 = ArgTy1->getAs<VectorType>();
5180+
SourceLocation BuiltinLoc = TheCall->getBeginLoc();
5181+
5182+
// if arg0 is bool then call Diag with err_builtin_invalid_arg_type
5183+
if (checkMathBuiltinElementType(*this, Arg0->getBeginLoc(), ArgTy0, 1)) {
5184+
return true;
5185+
}
5186+
5187+
// if arg1 is bool then call Diag with err_builtin_invalid_arg_type
5188+
if (checkMathBuiltinElementType(*this, Arg1->getBeginLoc(), ArgTy1, 2)) {
5189+
return true;
5190+
}
5191+
5192+
if (VecTy0 == nullptr && VecTy1 == nullptr) {
5193+
if (ArgTy0 != ArgTy1) {
5194+
return true;
5195+
} else {
5196+
return false;
5197+
}
5198+
}
5199+
5200+
if ((VecTy0 == nullptr && VecTy1 != nullptr) ||
5201+
(VecTy0 != nullptr && VecTy1 == nullptr)) {
5202+
5203+
Diag(BuiltinLoc, diag::err_vec_builtin_non_vector)
5204+
<< TheCall->getDirectCallee()
5205+
<< SourceRange(TheCall->getArg(0)->getBeginLoc(),
5206+
TheCall->getArg(1)->getEndLoc());
5207+
return true;
5208+
}
5209+
5210+
if (VecTy0->getElementType() != VecTy1->getElementType()) {
5211+
return true;
5212+
}
5213+
if (VecTy0->getNumElements() != VecTy1->getNumElements()) {
5214+
return true;
5215+
}
5216+
break;
5217+
}
5218+
}
5219+
return false;
5220+
}
5221+
51615222
bool Sema::CheckAMDGCNBuiltinFunctionCall(unsigned BuiltinID,
51625223
CallExpr *TheCall) {
51635224
// position of memory order and scope arguments in the builtin
@@ -19577,7 +19638,7 @@ bool Sema::PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall) {
1957719638
TheCall->setArg(0, A.get());
1957819639
QualType TyA = A.get()->getType();
1957919640

19580-
if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA))
19641+
if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA, 1))
1958119642
return true;
1958219643

1958319644
TheCall->setType(TyA);
@@ -19605,7 +19666,7 @@ bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) {
1960519666
diag::err_typecheck_call_different_arg_types)
1960619667
<< TyA << TyB;
1960719668

19608-
if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA))
19669+
if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA, 1))
1960919670
return true;
1961019671

1961119672
TheCall->setArg(0, A.get());

0 commit comments

Comments
 (0)