Skip to content

[HLSL] Implementation of dot intrinsic #81190

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -4524,6 +4524,12 @@ def HLSLCreateHandle : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void*(unsigned char)";
}

def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_dot"];
let Attributes = [NoThrow, Const];
let Prototype = "void(...)";
}

// Builtins for XRay.
def XRayCustomEvent : Builtin {
let Spellings = ["__xray_customevent"];
Expand Down
3 changes: 3 additions & 0 deletions clang/include/clang/Sema/Sema.h
Original file line number Diff line number Diff line change
Expand Up @@ -14057,6 +14057,7 @@ class Sema final {
bool CheckPPCBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
CallExpr *TheCall);
bool CheckAMDGCNBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
bool CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
bool CheckRISCVLMUL(CallExpr *TheCall, unsigned ArgNum);
bool CheckRISCVBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
CallExpr *TheCall);
Expand Down Expand Up @@ -14122,6 +14123,8 @@ class Sema final {

bool CheckPPCMMAType(QualType Type, SourceLocation TypeLoc);

bool SemaBuiltinVectorMath(CallExpr *TheCall, QualType &Res);
bool SemaBuiltinVectorToScalarMath(CallExpr *TheCall);
bool SemaBuiltinElementwiseMath(CallExpr *TheCall);
bool SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall);
bool PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall);
Expand Down
51 changes: 51 additions & 0 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/IntrinsicsARM.h"
#include "llvm/IR/IntrinsicsBPF.h"
#include "llvm/IR/IntrinsicsDirectX.h"
#include "llvm/IR/IntrinsicsHexagon.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/IntrinsicsPowerPC.h"
Expand Down Expand Up @@ -5982,6 +5983,10 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
llvm_unreachable("Bad evaluation kind in EmitBuiltinExpr");
}

// EmitHLSLBuiltinExpr will check getLangOpts().HLSL
if (Value *V = EmitHLSLBuiltinExpr(BuiltinID, E))
return RValue::get(V);

if (getLangOpts().HIPStdPar && getLangOpts().CUDAIsDevice)
return EmitHipStdParUnsupportedBuiltin(this, FD);

Expand Down Expand Up @@ -17959,6 +17964,52 @@ llvm::Value *CodeGenFunction::EmitScalarOrConstFoldImmArg(unsigned ICEArguments,
return Arg;
}

Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
const CallExpr *E) {
if (!getLangOpts().HLSL)
return nullptr;

switch (BuiltinID) {
case Builtin::BI__builtin_hlsl_dot: {
Value *Op0 = EmitScalarExpr(E->getArg(0));
Value *Op1 = EmitScalarExpr(E->getArg(1));
llvm::Type *T0 = Op0->getType();
llvm::Type *T1 = Op1->getType();
if (!T0->isVectorTy() && !T1->isVectorTy()) {
if (T0->isFloatingPointTy())
return Builder.CreateFMul(Op0, Op1, "dx.dot");

if (T0->isIntegerTy())
return Builder.CreateMul(Op0, Op1, "dx.dot");

// Bools should have been promoted
llvm_unreachable(
"Scalar dot product is only supported on ints and floats.");
}
// A VectorSplat should have happened
assert(T0->isVectorTy() && T1->isVectorTy() &&
"Dot product of vector and scalar is not supported.");

// A vector sext or sitofp should have happened
assert(T0->getScalarType() == T1->getScalarType() &&
"Dot product of vectors need the same element types.");

[[maybe_unused]] auto *VecTy0 =
E->getArg(0)->getType()->getAs<VectorType>();
[[maybe_unused]] auto *VecTy1 =
E->getArg(1)->getType()->getAs<VectorType>();
// A HLSLVectorTruncation should have happend
assert(VecTy0->getNumElements() == VecTy1->getNumElements() &&
"Dot product requires vectors to be of the same size.");

return Builder.CreateIntrinsic(
/*ReturnType*/ T0->getScalarType(), Intrinsic::dx_dot,
ArrayRef<Value *>{Op0, Op1}, nullptr, "dx.dot");
} break;
}
return nullptr;
}

Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
const CallExpr *E) {
llvm::AtomicOrdering AO = llvm::AtomicOrdering::SequentiallyConsistent;
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CodeGen/CodeGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -4405,6 +4405,7 @@ class CodeGenFunction : public CodeGenTypeCache {
llvm::Value *EmitX86BuiltinExpr(unsigned BuiltinID, const CallExpr *E);
llvm::Value *EmitPPCBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
llvm::Value *EmitAMDGPUBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
llvm::Value *EmitHLSLBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
llvm::Value *EmitScalarOrConstFoldImmArg(unsigned ICEArguments, unsigned Idx,
const CallExpr *E);
llvm::Value *EmitSystemZBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
Expand Down
98 changes: 98 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,104 @@ double3 cos(double3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_cos)
double4 cos(double4);

//===----------------------------------------------------------------------===//
// dot product builtins
//===----------------------------------------------------------------------===//

/// \fn K dot(T X, T Y)
/// \brief Return the dot product (a scalar value) of \a X and \a Y.
/// \param X The X input value.
/// \param Y The Y input value.

_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
half dot(half, half);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
half dot(half2, half2);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
half dot(half3, half3);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
half dot(half4, half4);

#ifdef __HLSL_ENABLE_16_BIT
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
int16_t dot(int16_t, int16_t);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
int16_t dot(int16_t2, int16_t2);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
int16_t dot(int16_t3, int16_t3);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
int16_t dot(int16_t4, int16_t4);

_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
uint16_t dot(uint16_t, uint16_t);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
uint16_t dot(uint16_t2, uint16_t2);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
uint16_t dot(uint16_t3, uint16_t3);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
uint16_t dot(uint16_t4, uint16_t4);
#endif

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
float dot(float, float);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
float dot(float2, float2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
float dot(float3, float3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
float dot(float4, float4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
double dot(double, double);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
int dot(int, int);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
int dot(int2, int2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
int dot(int3, int3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
int dot(int4, int4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
uint dot(uint, uint);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
uint dot(uint2, uint2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
uint dot(uint3, uint3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
uint dot(uint4, uint4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
int64_t dot(int64_t, int64_t);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
int64_t dot(int64_t2, int64_t2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
int64_t dot(int64_t3, int64_t3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
int64_t dot(int64_t4, int64_t4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
uint64_t dot(uint64_t, uint64_t);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
uint64_t dot(uint64_t2, uint64_t2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
uint64_t dot(uint64_t3, uint64_t3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
uint64_t dot(uint64_t4, uint64_t4);

//===----------------------------------------------------------------------===//
// floor builtins
//===----------------------------------------------------------------------===//
Expand Down
103 changes: 95 additions & 8 deletions clang/lib/Sema/SemaChecking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2120,10 +2120,11 @@ bool Sema::CheckTSBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
// not a valid type, emit an error message and return true. Otherwise return
// false.
static bool checkMathBuiltinElementType(Sema &S, SourceLocation Loc,
QualType Ty) {
if (!Ty->getAs<VectorType>() && !ConstantMatrixType::isValidElementType(Ty)) {
QualType ArgTy, int ArgIndex) {
if (!ArgTy->getAs<VectorType>() &&
!ConstantMatrixType::isValidElementType(ArgTy)) {
return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
<< 1 << /* vector, integer or float ty*/ 0 << Ty;
<< ArgIndex << /* vector, integer or float ty*/ 0 << ArgTy;
}

return false;
Expand Down Expand Up @@ -2961,6 +2962,9 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
}
}

if (getLangOpts().HLSL && CheckHLSLBuiltinFunctionCall(BuiltinID, TheCall))
return ExprError();

// Since the target specific builtins for each arch overlap, only check those
// of the arch we are compiling for.
if (Context.BuiltinInfo.isTSBuiltin(BuiltinID)) {
Expand Down Expand Up @@ -5161,6 +5165,70 @@ bool Sema::CheckPPCMMAType(QualType Type, SourceLocation TypeLoc) {
return false;
}

// Helper function for CheckHLSLBuiltinFunctionCall
bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
assert(TheCall->getNumArgs() > 1);
ExprResult A = TheCall->getArg(0);
ExprResult B = TheCall->getArg(1);
QualType ArgTyA = A.get()->getType();
QualType ArgTyB = B.get()->getType();
auto *VecTyA = ArgTyA->getAs<VectorType>();
auto *VecTyB = ArgTyB->getAs<VectorType>();
SourceLocation BuiltinLoc = TheCall->getBeginLoc();
if (VecTyA == nullptr && VecTyB == nullptr)
return false;

if (VecTyA && VecTyB) {
bool retValue = false;
if (VecTyA->getElementType() != VecTyB->getElementType()) {
// Note: type promotion is intended to be handeled via the intrinsics
// and not the builtin itself.
S->Diag(TheCall->getBeginLoc(), diag::err_vec_builtin_incompatible_vector)
<< TheCall->getDirectCallee()
<< SourceRange(A.get()->getBeginLoc(), B.get()->getEndLoc());
retValue = true;
}
if (VecTyA->getNumElements() != VecTyB->getNumElements()) {
// if we get here a HLSLVectorTruncation is needed.
S->Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_vector)
<< TheCall->getDirectCallee()
<< SourceRange(TheCall->getArg(0)->getBeginLoc(),
TheCall->getArg(1)->getEndLoc());
retValue = true;
}

if (retValue)
TheCall->setType(VecTyA->getElementType());

return retValue;
}

// Note: if we get here one of the args is a scalar which
// requires a VectorSplat on Arg0 or Arg1
S->Diag(BuiltinLoc, diag::err_vec_builtin_non_vector)
<< TheCall->getDirectCallee()
<< SourceRange(TheCall->getArg(0)->getBeginLoc(),
TheCall->getArg(1)->getEndLoc());
return true;
}

// Note: returning true in this case results in CheckBuiltinFunctionCall
// returning an ExprError
bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
switch (BuiltinID) {
case Builtin::BI__builtin_hlsl_dot: {
if (checkArgCount(*this, TheCall, 2))
return true;
if (CheckVectorElementCallArgs(this, TheCall))
return true;
if (SemaBuiltinVectorToScalarMath(TheCall))
return true;
break;
}
}
return false;
}

bool Sema::CheckAMDGCNBuiltinFunctionCall(unsigned BuiltinID,
CallExpr *TheCall) {
// position of memory order and scope arguments in the builtin
Expand Down Expand Up @@ -19594,23 +19662,43 @@ bool Sema::PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall) {
TheCall->setArg(0, A.get());
QualType TyA = A.get()->getType();

if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA))
if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA, 1))
return true;

TheCall->setType(TyA);
return false;
}

bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) {
QualType Res;
if (SemaBuiltinVectorMath(TheCall, Res))
return true;
TheCall->setType(Res);
return false;
}

bool Sema::SemaBuiltinVectorToScalarMath(CallExpr *TheCall) {
QualType Res;
if (SemaBuiltinVectorMath(TheCall, Res))
return true;

if (auto *VecTy0 = Res->getAs<VectorType>())
TheCall->setType(VecTy0->getElementType());
else
TheCall->setType(Res);

return false;
}

bool Sema::SemaBuiltinVectorMath(CallExpr *TheCall, QualType &Res) {
if (checkArgCount(*this, TheCall, 2))
return true;

ExprResult A = TheCall->getArg(0);
ExprResult B = TheCall->getArg(1);
// Do standard promotions between the two arguments, returning their common
// type.
QualType Res =
UsualArithmeticConversions(A, B, TheCall->getExprLoc(), ACK_Comparison);
Res = UsualArithmeticConversions(A, B, TheCall->getExprLoc(), ACK_Comparison);
if (A.isInvalid() || B.isInvalid())
return true;

Expand All @@ -19622,12 +19710,11 @@ bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) {
diag::err_typecheck_call_different_arg_types)
<< TyA << TyB;

if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA))
if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA, 1))
return true;

TheCall->setArg(0, A.get());
TheCall->setArg(1, B.get());
TheCall->setType(Res);
return false;
}

Expand Down
Loading