Skip to content

Commit 1b2d11d

Browse files
authored
Add normalize builtins and normalize HLSL function to DirectX and SPIR-V backend (#102683)
This PR adds the normalize intrinsic and an HLSL function that uses it. The SPIRV backend is also implemented. Used #101256 as a reference, along with #102243 Fixes #99139
1 parent 643a208 commit 1b2d11d

File tree

14 files changed

+448
-0
lines changed

14 files changed

+448
-0
lines changed

clang/include/clang/Basic/Builtins.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4725,6 +4725,12 @@ def HLSLMad : LangBuiltin<"HLSL_LANG"> {
47254725
let Prototype = "void(...)";
47264726
}
47274727

4728+
def HLSLNormalize : LangBuiltin<"HLSL_LANG"> {
4729+
let Spellings = ["__builtin_hlsl_normalize"];
4730+
let Attributes = [NoThrow, Const];
4731+
let Prototype = "void(...)";
4732+
}
4733+
47284734
def HLSLRcp : LangBuiltin<"HLSL_LANG"> {
47294735
let Spellings = ["__builtin_hlsl_elementwise_rcp"];
47304736
let Attributes = [NoThrow, Const];

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18584,6 +18584,17 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
1858418584
CGM.getHLSLRuntime().getLengthIntrinsic(), ArrayRef<Value *>{X},
1858518585
nullptr, "hlsl.length");
1858618586
}
18587+
case Builtin::BI__builtin_hlsl_normalize: {
18588+
Value *X = EmitScalarExpr(E->getArg(0));
18589+
18590+
assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
18591+
"normalize operand must have a float representation");
18592+
18593+
return Builder.CreateIntrinsic(
18594+
/*ReturnType=*/X->getType(),
18595+
CGM.getHLSLRuntime().getNormalizeIntrinsic(), ArrayRef<Value *>{X},
18596+
nullptr, "hlsl.normalize");
18597+
}
1858718598
case Builtin::BI__builtin_hlsl_elementwise_frac: {
1858818599
Value *Op0 = EmitScalarExpr(E->getArg(0));
1858918600
if (!E->getArg(0)->getType()->hasFloatingRepresentation())

clang/lib/CodeGen/CGHLSLRuntime.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ class CGHLSLRuntime {
7777
GENERATE_HLSL_INTRINSIC_FUNCTION(Frac, frac)
7878
GENERATE_HLSL_INTRINSIC_FUNCTION(Length, length)
7979
GENERATE_HLSL_INTRINSIC_FUNCTION(Lerp, lerp)
80+
GENERATE_HLSL_INTRINSIC_FUNCTION(Normalize, normalize)
8081
GENERATE_HLSL_INTRINSIC_FUNCTION(Rsqrt, rsqrt)
8182
GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id)
8283

clang/lib/Headers/hlsl/hlsl_intrinsics.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1352,6 +1352,38 @@ double3 min(double3, double3);
13521352
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
13531353
double4 min(double4, double4);
13541354

1355+
//===----------------------------------------------------------------------===//
1356+
// normalize builtins
1357+
//===----------------------------------------------------------------------===//
1358+
1359+
/// \fn T normalize(T x)
1360+
/// \brief Returns the normalized unit vector of the specified floating-point
1361+
/// vector. \param x [in] The vector of floats.
1362+
///
1363+
/// Normalize is based on the following formula: x / length(x).
1364+
1365+
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
1366+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
1367+
half normalize(half);
1368+
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
1369+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
1370+
half2 normalize(half2);
1371+
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
1372+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
1373+
half3 normalize(half3);
1374+
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
1375+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
1376+
half4 normalize(half4);
1377+
1378+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
1379+
float normalize(float);
1380+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
1381+
float2 normalize(float2);
1382+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
1383+
float3 normalize(float3);
1384+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
1385+
float4 normalize(float4);
1386+
13551387
//===----------------------------------------------------------------------===//
13561388
// pow builtins
13571389
//===----------------------------------------------------------------------===//

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1108,6 +1108,18 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
11081108
return true;
11091109
break;
11101110
}
1111+
case Builtin::BI__builtin_hlsl_normalize: {
1112+
if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
1113+
return true;
1114+
if (SemaRef.checkArgCount(TheCall, 1))
1115+
return true;
1116+
1117+
ExprResult A = TheCall->getArg(0);
1118+
QualType ArgTyA = A.get()->getType();
1119+
// return type is the same as the input type
1120+
TheCall->setType(ArgTyA);
1121+
break;
1122+
}
11111123
// Note these are llvm builtins that we want to catch invalid intrinsic
11121124
// generation. Normal handling of these builitns will occur elsewhere.
11131125
case Builtin::BI__builtin_elementwise_bitreverse: {
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
2+
// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
3+
// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
4+
// RUN: --check-prefixes=CHECK,DXIL_CHECK,DXIL_NATIVE_HALF,NATIVE_HALF
5+
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
6+
// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
7+
// RUN: -o - | FileCheck %s --check-prefixes=CHECK,DXIL_CHECK,NO_HALF,DXIL_NO_HALF
8+
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
9+
// RUN: spirv-unknown-vulkan-compute %s -fnative-half-type \
10+
// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
11+
// RUN: --check-prefixes=CHECK,NATIVE_HALF,SPIR_NATIVE_HALF,SPIR_CHECK
12+
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
13+
// RUN: spirv-unknown-vulkan-compute %s -emit-llvm -disable-llvm-passes \
14+
// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF,SPIR_NO_HALF,SPIR_CHECK
15+
16+
// DXIL_NATIVE_HALF: define noundef half @
17+
// SPIR_NATIVE_HALF: define spir_func noundef half @
18+
// DXIL_NATIVE_HALF: call half @llvm.dx.normalize.f16(half
19+
// SPIR_NATIVE_HALF: call half @llvm.spv.normalize.f16(half
20+
// DXIL_NO_HALF: call float @llvm.dx.normalize.f32(float
21+
// SPIR_NO_HALF: call float @llvm.spv.normalize.f32(float
22+
// NATIVE_HALF: ret half
23+
// NO_HALF: ret float
24+
half test_normalize_half(half p0)
25+
{
26+
return normalize(p0);
27+
}
28+
// DXIL_NATIVE_HALF: define noundef <2 x half> @
29+
// SPIR_NATIVE_HALF: define spir_func noundef <2 x half> @
30+
// DXIL_NATIVE_HALF: call <2 x half> @llvm.dx.normalize.v2f16(<2 x half>
31+
// SPIR_NATIVE_HALF: call <2 x half> @llvm.spv.normalize.v2f16(<2 x half>
32+
// DXIL_NO_HALF: call <2 x float> @llvm.dx.normalize.v2f32(<2 x float>
33+
// SPIR_NO_HALF: call <2 x float> @llvm.spv.normalize.v2f32(<2 x float>
34+
// NATIVE_HALF: ret <2 x half> %hlsl.normalize
35+
// NO_HALF: ret <2 x float> %hlsl.normalize
36+
half2 test_normalize_half2(half2 p0)
37+
{
38+
return normalize(p0);
39+
}
40+
// DXIL_NATIVE_HALF: define noundef <3 x half> @
41+
// SPIR_NATIVE_HALF: define spir_func noundef <3 x half> @
42+
// DXIL_NATIVE_HALF: call <3 x half> @llvm.dx.normalize.v3f16(<3 x half>
43+
// SPIR_NATIVE_HALF: call <3 x half> @llvm.spv.normalize.v3f16(<3 x half>
44+
// DXIL_NO_HALF: call <3 x float> @llvm.dx.normalize.v3f32(<3 x float>
45+
// SPIR_NO_HALF: call <3 x float> @llvm.spv.normalize.v3f32(<3 x float>
46+
// NATIVE_HALF: ret <3 x half> %hlsl.normalize
47+
// NO_HALF: ret <3 x float> %hlsl.normalize
48+
half3 test_normalize_half3(half3 p0)
49+
{
50+
return normalize(p0);
51+
}
52+
// DXIL_NATIVE_HALF: define noundef <4 x half> @
53+
// SPIR_NATIVE_HALF: define spir_func noundef <4 x half> @
54+
// DXIL_NATIVE_HALF: call <4 x half> @llvm.dx.normalize.v4f16(<4 x half>
55+
// SPIR_NATIVE_HALF: call <4 x half> @llvm.spv.normalize.v4f16(<4 x half>
56+
// DXIL_NO_HALF: call <4 x float> @llvm.dx.normalize.v4f32(<4 x float>
57+
// SPIR_NO_HALF: call <4 x float> @llvm.spv.normalize.v4f32(<4 x float>
58+
// NATIVE_HALF: ret <4 x half> %hlsl.normalize
59+
// NO_HALF: ret <4 x float> %hlsl.normalize
60+
half4 test_normalize_half4(half4 p0)
61+
{
62+
return normalize(p0);
63+
}
64+
65+
// DXIL_CHECK: define noundef float @
66+
// SPIR_CHECK: define spir_func noundef float @
67+
// DXIL_CHECK: call float @llvm.dx.normalize.f32(float
68+
// SPIR_CHECK: call float @llvm.spv.normalize.f32(float
69+
// CHECK: ret float
70+
float test_normalize_float(float p0)
71+
{
72+
return normalize(p0);
73+
}
74+
// DXIL_CHECK: define noundef <2 x float> @
75+
// SPIR_CHECK: define spir_func noundef <2 x float> @
76+
// DXIL_CHECK: %hlsl.normalize = call <2 x float> @llvm.dx.normalize.v2f32(
77+
// SPIR_CHECK: %hlsl.normalize = call <2 x float> @llvm.spv.normalize.v2f32(<2 x float>
78+
// CHECK: ret <2 x float> %hlsl.normalize
79+
float2 test_normalize_float2(float2 p0)
80+
{
81+
return normalize(p0);
82+
}
83+
// DXIL_CHECK: define noundef <3 x float> @
84+
// SPIR_CHECK: define spir_func noundef <3 x float> @
85+
// DXIL_CHECK: %hlsl.normalize = call <3 x float> @llvm.dx.normalize.v3f32(
86+
// SPIR_CHECK: %hlsl.normalize = call <3 x float> @llvm.spv.normalize.v3f32(<3 x float>
87+
// CHECK: ret <3 x float> %hlsl.normalize
88+
float3 test_normalize_float3(float3 p0)
89+
{
90+
return normalize(p0);
91+
}
92+
// DXIL_CHECK: define noundef <4 x float> @
93+
// SPIR_CHECK: define spir_func noundef <4 x float> @
94+
// DXIL_CHECK: %hlsl.normalize = call <4 x float> @llvm.dx.normalize.v4f32(
95+
// SPIR_CHECK: %hlsl.normalize = call <4 x float> @llvm.spv.normalize.v4f32(
96+
// CHECK: ret <4 x float> %hlsl.normalize
97+
float4 test_length_float4(float4 p0)
98+
{
99+
return normalize(p0);
100+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -verify -verify-ignore-unexpected
2+
3+
void test_too_few_arg()
4+
{
5+
return __builtin_hlsl_normalize();
6+
// expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
7+
}
8+
9+
void test_too_many_arg(float2 p0)
10+
{
11+
return __builtin_hlsl_normalize(p0, p0);
12+
// expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
13+
}
14+
15+
bool builtin_bool_to_float_type_promotion(bool p1)
16+
{
17+
return __builtin_hlsl_normalize(p1);
18+
// expected-error@-1 {passing 'bool' to parameter of incompatible type 'float'}}
19+
}
20+
21+
bool builtin_normalize_int_to_float_promotion(int p1)
22+
{
23+
return __builtin_hlsl_normalize(p1);
24+
// expected-error@-1 {{passing 'int' to parameter of incompatible type 'float'}}
25+
}
26+
27+
bool2 builtin_normalize_int2_to_float2_promotion(int2 p1)
28+
{
29+
return __builtin_hlsl_normalize(p1);
30+
// expected-error@-1 {{passing 'int2' (aka 'vector<int, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values)}}
31+
}

llvm/include/llvm/IR/IntrinsicsDirectX.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def int_dx_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType
5858
def int_dx_length : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty]>;
5959
def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
6060
def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
61+
def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
6162
def int_dx_rcp : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
6263
def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
6364
}

llvm/include/llvm/IR/IntrinsicsSPIRV.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,5 +64,6 @@ let TargetPrefix = "spv" in {
6464
def int_spv_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>,LLVMMatchType<0>],
6565
[IntrNoMem, IntrWillReturn] >;
6666
def int_spv_length : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty]>;
67+
def int_spv_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
6768
def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
6869
}

llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ static bool isIntrinsicExpansion(Function &F) {
4343
case Intrinsic::dx_uclamp:
4444
case Intrinsic::dx_lerp:
4545
case Intrinsic::dx_length:
46+
case Intrinsic::dx_normalize:
4647
case Intrinsic::dx_sdot:
4748
case Intrinsic::dx_udot:
4849
return true;
@@ -229,6 +230,75 @@ static bool expandLog10Intrinsic(CallInst *Orig) {
229230
return expandLogIntrinsic(Orig, numbers::ln2f / numbers::ln10f);
230231
}
231232

233+
static bool expandNormalizeIntrinsic(CallInst *Orig) {
234+
Value *X = Orig->getOperand(0);
235+
Type *Ty = Orig->getType();
236+
Type *EltTy = Ty->getScalarType();
237+
IRBuilder<> Builder(Orig->getParent());
238+
Builder.SetInsertPoint(Orig);
239+
240+
auto *XVec = dyn_cast<FixedVectorType>(Ty);
241+
if (!XVec) {
242+
if (auto *constantFP = dyn_cast<ConstantFP>(X)) {
243+
const APFloat &fpVal = constantFP->getValueAPF();
244+
if (fpVal.isZero())
245+
report_fatal_error(Twine("Invalid input scalar: length is zero"),
246+
/* gen_crash_diag=*/false);
247+
}
248+
Value *Result = Builder.CreateFDiv(X, X);
249+
250+
Orig->replaceAllUsesWith(Result);
251+
Orig->eraseFromParent();
252+
return true;
253+
}
254+
255+
Value *Elt = Builder.CreateExtractElement(X, (uint64_t)0);
256+
unsigned XVecSize = XVec->getNumElements();
257+
Value *DotProduct = nullptr;
258+
// use the dot intrinsic corresponding to the vector size
259+
switch (XVecSize) {
260+
case 1:
261+
report_fatal_error(Twine("Invalid input vector: length is zero"),
262+
/* gen_crash_diag=*/false);
263+
break;
264+
case 2:
265+
DotProduct = Builder.CreateIntrinsic(
266+
EltTy, Intrinsic::dx_dot2, ArrayRef<Value *>{X, X}, nullptr, "dx.dot2");
267+
break;
268+
case 3:
269+
DotProduct = Builder.CreateIntrinsic(
270+
EltTy, Intrinsic::dx_dot3, ArrayRef<Value *>{X, X}, nullptr, "dx.dot3");
271+
break;
272+
case 4:
273+
DotProduct = Builder.CreateIntrinsic(
274+
EltTy, Intrinsic::dx_dot4, ArrayRef<Value *>{X, X}, nullptr, "dx.dot4");
275+
break;
276+
default:
277+
report_fatal_error(Twine("Invalid input vector: vector size is invalid."),
278+
/* gen_crash_diag=*/false);
279+
}
280+
281+
// verify that the length is non-zero
282+
// (if the dot product is non-zero, then the length is non-zero)
283+
if (auto *constantFP = dyn_cast<ConstantFP>(DotProduct)) {
284+
const APFloat &fpVal = constantFP->getValueAPF();
285+
if (fpVal.isZero())
286+
report_fatal_error(Twine("Invalid input vector: length is zero"),
287+
/* gen_crash_diag=*/false);
288+
}
289+
290+
Value *Multiplicand = Builder.CreateIntrinsic(EltTy, Intrinsic::dx_rsqrt,
291+
ArrayRef<Value *>{DotProduct},
292+
nullptr, "dx.rsqrt");
293+
294+
Value *MultiplicandVec = Builder.CreateVectorSplat(XVecSize, Multiplicand);
295+
Value *Result = Builder.CreateFMul(X, MultiplicandVec);
296+
297+
Orig->replaceAllUsesWith(Result);
298+
Orig->eraseFromParent();
299+
return true;
300+
}
301+
232302
static bool expandPowIntrinsic(CallInst *Orig) {
233303

234304
Value *X = Orig->getOperand(0);
@@ -314,6 +384,8 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
314384
return expandLerpIntrinsic(Orig);
315385
case Intrinsic::dx_length:
316386
return expandLengthIntrinsic(Orig);
387+
case Intrinsic::dx_normalize:
388+
return expandNormalizeIntrinsic(Orig);
317389
case Intrinsic::dx_sdot:
318390
case Intrinsic::dx_udot:
319391
return expandIntegerDot(Orig, F.getIntrinsicID());

0 commit comments

Comments
 (0)