diff --git a/tools/clang/lib/Headers/hlsl/dx/linalg.h b/tools/clang/lib/Headers/hlsl/dx/linalg.h
new file mode 100644
index 0000000000..51e662bbc9
--- /dev/null
+++ b/tools/clang/lib/Headers/hlsl/dx/linalg.h
@@ -0,0 +1,182 @@
+// Header for linear algebra APIs.
+
+#if __spirv__
+#error "Cooperative vectors not (yet) supported for SPIRV"
+#endif
+
+#if ((__SHADER_TARGET_MAJOR > 6) ||                                           \
+     (__SHADER_TARGET_MAJOR == 6 && __SHADER_TARGET_MINOR >= 9)) &&           \
+    (__HLSL_VERSION >= 2021)
+
+namespace dx {
+namespace linalg {
+
+// NOTE: can't be an enum class because we get this error:
+//   error: non-type template argument of type 'dx::linalg::DataType' is not
+//   an integral constant expression
+//
+enum DataType {
+  DATA_TYPE_SINT16 = 2,           // ComponentType::I16
+  DATA_TYPE_UINT16 = 3,           // ComponentType::U16
+  DATA_TYPE_SINT32 = 4,           // ComponentType::I32
+  DATA_TYPE_UINT32 = 5,           // ComponentType::U32
+  DATA_TYPE_FLOAT16 = 8,          // ComponentType::F16
+  DATA_TYPE_FLOAT32 = 9,          // ComponentType::F32
+  DATA_TYPE_SINT8_T4_PACKED = 17, // ComponentType::PackedS8x32
+  DATA_TYPE_UINT8_T4_PACKED = 18, // ComponentType::PackedU8x32
+  DATA_TYPE_UINT8 = 19,           // ComponentType::U8
+  DATA_TYPE_SINT8 = 20,           // ComponentType::I8
+  DATA_TYPE_FLOAT8_E4M3 = 21,     // ComponentType::F8_E4M3
+                                  // (1 sign, 4 exp, 3 mantissa bits)
+  DATA_TYPE_FLOAT8_E5M2 = 22,     // ComponentType::F8_E5M2
+                                  // (1 sign, 5 exp, 2 mantissa bits)
+};
+
+enum MatrixLayout {
+  MATRIX_LAYOUT_ROW_MAJOR = 0,
+  MATRIX_LAYOUT_COLUMN_MAJOR = 1,
+  MATRIX_LAYOUT_MUL_OPTIMAL = 2,
+  MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL = 3
+};
+
+//
+// Helper for signedness
+//
+namespace details {
+template <typename T> bool IsUnsigned() { return false; }
+
+#ifdef __HLSL_ENABLE_16_BIT
+template <> bool IsUnsigned<uint16_t>() { return true; }
+#endif
+
+template <> bool IsUnsigned<uint32_t>() { return true; }
+template <> bool IsUnsigned<uint64_t>() { return true; }
+} // namespace details
+
+//
+// (RW)MatrixRef
+//
+
+template <typename BufferTy, DataType DT, uint M, uint K, MatrixLayout ML,
+          bool Transpose>
+struct MatrixRefImpl {
+  BufferTy Buffer;
+  uint StartOffset;
+  uint Stride;
+};
+
+template <DataType DT, uint M, uint K, MatrixLayout ML, bool Transpose = false>
+using MatrixRef = MatrixRefImpl<ByteAddressBuffer, DT, M, K, ML, Transpose>;
+
+template <DataType DT, uint M, uint K, MatrixLayout ML, bool Transpose = false>
+using RWMatrixRef = MatrixRefImpl<RWByteAddressBuffer, DT, M, K, ML, Transpose>;
+
+//
+// (RW)VectorRef
+//
+
+template <typename BufferTy, DataType DT> struct VectorRefImpl {
+  BufferTy Buffer;
+  uint StartOffset;
+};
+
+template <DataType DT> using VectorRef = VectorRefImpl<ByteAddressBuffer, DT>;
+
+template <DataType DT>
+using RWVectorRef = VectorRefImpl<RWByteAddressBuffer, DT>;
+
+//
+// Vector
+//
+
+template <typename T, int N, DataType DT> struct InterpretedVector {
+  vector<T, N> Data;
+};
+
+template <DataType DT, typename T, int N>
+InterpretedVector<T, N, DT> MakeInterpretedVector(vector<T, N> Vec) {
+  InterpretedVector<T, N, DT> IV = {Vec};
+  return IV;
+}
+
+//
+// Mul
+//
+
+template <typename OutputElTy, typename InputElTy, int InputElCount,
+          typename MatrixBufferTy, DataType InputDT, DataType MatrixDT,
+          uint MatrixM, uint MatrixK, MatrixLayout MatrixLayout,
+          bool MatrixTranspose>
+vector<OutputElTy, MatrixM>
+Mul(MatrixRefImpl<MatrixBufferTy, MatrixDT, MatrixM, MatrixK, MatrixLayout,
+                  MatrixTranspose>
+        Matrix,
+    InterpretedVector<InputElTy, InputElCount, InputDT> InputVector) {
+
+  vector<OutputElTy, MatrixM> OutputVector;
+
+  __builtin_MatVecMul(
+      /*out*/ OutputVector, details::IsUnsigned<OutputElTy>(), InputVector.Data,
+      details::IsUnsigned<InputElTy>(), InputDT, Matrix.Buffer,
+      Matrix.StartOffset, MatrixDT, MatrixM, MatrixK, MatrixLayout,
+      MatrixTranspose, Matrix.Stride);
+
+  return OutputVector;
+}
+
+//
+// MulAdd
+//
+
+template <typename OutputElTy, typename InputElTy, int InputElCount,
+          typename MatrixBufferTy, DataType InputDT, DataType MatrixDT,
+          uint MatrixM, uint MatrixK, MatrixLayout MatrixLayout,
+          bool MatrixTranspose, typename BiasVectorBufferTy,
+          DataType BiasVectorDT>
+vector<OutputElTy, MatrixM>
+MulAdd(MatrixRefImpl<MatrixBufferTy, MatrixDT, MatrixM, MatrixK, MatrixLayout,
+                     MatrixTranspose>
+           Matrix,
+       InterpretedVector<InputElTy, InputElCount, InputDT> InputVector,
+       VectorRefImpl<BiasVectorBufferTy, BiasVectorDT> BiasVector) {
+
+  vector<OutputElTy, MatrixM> OutputVector;
+
+  __builtin_MatVecMulAdd(
+      /*out*/ OutputVector, details::IsUnsigned<OutputElTy>(), InputVector.Data,
+      details::IsUnsigned<InputElTy>(), InputDT, Matrix.Buffer,
+      Matrix.StartOffset, MatrixDT, MatrixM, MatrixK, MatrixLayout,
+      MatrixTranspose, Matrix.Stride, BiasVector.Buffer, BiasVector.StartOffset,
+      BiasVectorDT);
+
+  return OutputVector;
+}
+
+//
+// OuterProductAccumulate
+//
+
+template <typename ElTy, int ElCount1, int ElCount2, DataType MatrixDT,
+          MatrixLayout MatrixLayout>
+void OuterProductAccumulate(
+    vector<ElTy, ElCount1> InputVector1, vector<ElTy, ElCount2> InputVector2,
+    RWMatrixRef<MatrixDT, ElCount1, ElCount2, MatrixLayout, false> Matrix) {
+
+  __builtin_OuterProductAccumulate(InputVector1, InputVector2, Matrix.Buffer,
+                                   Matrix.StartOffset, MatrixDT, MatrixLayout,
+                                   Matrix.Stride);
+}
+
+//
+// VectorAccumulate
+//
+
+template <typename ElTy, int ElCount>
+void VectorAccumulate(vector<ElTy, ElCount> InputVector,
+                      RWByteAddressBuffer Buffer, uint Offset) {
+  __builtin_VectorAccumulate(InputVector, Buffer, Offset);
+}
+
+} // namespace linalg
+} // namespace dx
+
+#endif // SM 6.9 check and HV version check
diff --git a/tools/clang/test/CodeGenDXIL/hlsl/linalg/mat-vec-mul.hlsl b/tools/clang/test/CodeGenDXIL/hlsl/linalg/mat-vec-mul.hlsl
new file mode 100644
index 0000000000..141801c71c
--- /dev/null
+++ b/tools/clang/test/CodeGenDXIL/hlsl/linalg/mat-vec-mul.hlsl
@@ -0,0 +1,40 @@
+// RUN: %dxc -I %hlsl_headers -T lib_6_9 -enable-16bit-types %s | FileCheck %s
+
+#include <dx/linalg.h>
+
+ByteAddressBuffer Buf;
+
+export float4 Test1(vector<float, 4> Input) {
+  using namespace dx::linalg;
+
+  MatrixRef<DATA_TYPE_FLOAT16, 4, 4, MATRIX_LAYOUT_MUL_OPTIMAL, true> Matrix = {
+      Buf, 0, 0};
+
+  // CHECK: %{{.+}} = call <4 x float> @dx.op.matVecMul.v4f32.v4f32(i32 305, <4 x float> %{{.+}}, i1 false, i32 8, %dx.types.Handle %{{.+}}, i32 0, i32 8, i32 4, i32 4, i32 2, i1 true, i32 0, i1 false)
+  return Mul<float>(
+      Matrix, MakeInterpretedVector<DATA_TYPE_FLOAT16>(Input));
+}
+
+export vector<float, 8> Test2(vector<float, 6> Input) {
+  using namespace dx::linalg;
+
+  MatrixRef<DATA_TYPE_UINT8, 8, 24, MATRIX_LAYOUT_MUL_OPTIMAL> Matrix = {
+      Buf, 0, 0};
+
+  // note the stride argument is dropped.
+  // CHECK: %{{.+}} = call <8 x float> @dx.op.matVecMul.v8f32.v6f32(i32 305, <6 x float> %{{.+}}, i1 false, i32 18, %dx.types.Handle %{{.+}}, i32 0, i32 19, i32 8, i32 24, i32 2, i1 false, i32 0, i1 false)
+  return Mul<float>(Matrix,
+                    MakeInterpretedVector<DATA_TYPE_UINT8_T4_PACKED>(Input));
+}
+
+// test that "stride" isn't ignored in non-optimal layouts
+export vector<float, 8> Test3(vector<float, 6> Input) {
+  using namespace dx::linalg;
+
+  MatrixRef<DATA_TYPE_UINT8, 8, 24, MATRIX_LAYOUT_ROW_MAJOR> Matrix = {
+      Buf, 0, 6 * 4 * 8};
+
+  // CHECK: %{{.+}} = call <8 x float> @dx.op.matVecMul.v8f32.v6f32(i32 305, <6 x float> %{{.+}}, i1 false, i32 18, %dx.types.Handle %{{.+}}, i32 0, i32 19, i32 8, i32 24, i32 0, i1 false, i32 192, i1 false)
+  return Mul<float>(Matrix,
+                    MakeInterpretedVector<DATA_TYPE_UINT8_T4_PACKED>(Input));
+}
diff --git a/tools/clang/test/CodeGenDXIL/hlsl/linalg/mat-vec-muladd.hlsl b/tools/clang/test/CodeGenDXIL/hlsl/linalg/mat-vec-muladd.hlsl
new file mode 100644
index 0000000000..c19e601904
--- /dev/null
+++ b/tools/clang/test/CodeGenDXIL/hlsl/linalg/mat-vec-muladd.hlsl
@@ -0,0 +1,90 @@
+// RUN: %dxc -I %hlsl_headers -T lib_6_9 %s | FileCheck %s
+
+#include <dx/linalg.h>
+
+ByteAddressBuffer Buf;
+
+export float4 Test1(float4 input) {
+  using namespace dx::linalg;
+
+  MatrixRef<DATA_TYPE_FLOAT16, 4, 4, MATRIX_LAYOUT_MUL_OPTIMAL> matrix = {
+      Buf, 0, 0};
+  VectorRef<DATA_TYPE_FLOAT16> biasVector = {Buf, 256};
+
+  InterpretedVector<float, 4, DATA_TYPE_FLOAT16> theVector = {input};
+
+  // CHECK: %{{.+}} = call <4 x float> @dx.op.matVecMulAdd.v4f32.v4f32(i32 306, <4 x float> %{{.+}}, i1 false, i32 8, %dx.types.Handle [[RES:%.+]], i32 0, i32 8, i32 4, i32 4, i32 2, i1 false, i32 0, %dx.types.Handle [[RES]], i32 256, i32 8, i1 false)
+  return MulAdd<float>(
+      matrix, theVector,
+      biasVector);
+}
+
+export float4 Test2(float4 input) {
+  using namespace dx::linalg;
+
+  MatrixRef<DATA_TYPE_FLOAT16, 4, 4, MATRIX_LAYOUT_MUL_OPTIMAL, true> matrix = {
+      Buf, 0, 0};
+  VectorRef<DATA_TYPE_FLOAT16> biasVector = {Buf, 256};
+
+  InterpretedVector<float, 4, DATA_TYPE_FLOAT16> theVector = {input};
+
+  // CHECK: %{{.+}} = call <4 x float> @dx.op.matVecMulAdd.v4f32.v4f32(i32 306, <4 x float> %{{.+}}, i1 false, i32 8, %dx.types.Handle [[RES:%.+]], i32 0, i32 8, i32 4, i32 4, i32 2, i1 true, i32 0, %dx.types.Handle [[RES]], i32 256, i32 8, i1 false)
+  return MulAdd<float>(
+      matrix, theVector,
+      biasVector);
+}
+
+export float4 Test3(float4 input) {
+  using namespace dx::linalg;
+
+  MatrixRef<DATA_TYPE_FLOAT16, 4, 4, MATRIX_LAYOUT_MUL_OPTIMAL, true> matrix = {
+      Buf, 0, 0};
+  VectorRef<DATA_TYPE_FLOAT16>
+      biasVector = {Buf, 256};
+
+  // CHECK: %{{.+}} = call <4 x float> @dx.op.matVecMulAdd.v4f32.v4f32(i32 306, <4 x float> %{{.+}}, i1 false, i32 8, %dx.types.Handle [[RES:%.+]], i32 0, i32 8, i32 4, i32 4, i32 2, i1 true, i32 0, %dx.types.Handle [[RES]], i32 256, i32 8, i1 false)
+  return MulAdd<float>(
+      matrix, MakeInterpretedVector<DATA_TYPE_FLOAT16>(input),
+      biasVector);
+}
+
+namespace ProposalExample {
+
+ByteAddressBuffer model;
+
+vector<half, 3> ApplyNeuralMaterial(vector<half, 8> inputVector) {
+  using namespace dx::linalg;
+
+  MatrixRef<DATA_TYPE_FLOAT8_E4M3, 32, 8, MATRIX_LAYOUT_MUL_OPTIMAL> matrix0 = {
+      model, 0, 0};
+
+  VectorRef<DATA_TYPE_FLOAT16> biasVector0 = {model, 1024};
+
+  MatrixRef<DATA_TYPE_FLOAT8_E4M3, 32, 32, MATRIX_LAYOUT_MUL_OPTIMAL> matrix1 =
+      {model, 2048, 0};
+
+  VectorRef<DATA_TYPE_FLOAT16> biasVector1 = {model, 3072};
+
+  MatrixRef<DATA_TYPE_FLOAT8_E4M3, 3, 32, MATRIX_LAYOUT_MUL_OPTIMAL> matrix2 = {
+      model, 4096, 0};
+
+  VectorRef<DATA_TYPE_FLOAT16> biasVector2 = {model, 5120};
+
+  vector<half, 32> layer0 = MulAdd<half>(
+      matrix0, MakeInterpretedVector<DATA_TYPE_FLOAT16>(inputVector),
+      biasVector0);
+  layer0 = max(layer0, 0);
+
+  vector<half, 32> layer1 = MulAdd<half>(
+      matrix1, MakeInterpretedVector<DATA_TYPE_FLOAT16>(layer0),
+      biasVector1);
+  layer1 = max(layer1, 0);
+
+  vector<half, 3> output = MulAdd<half>(
+      matrix2, MakeInterpretedVector<DATA_TYPE_FLOAT16>(layer1),
+      biasVector2);
+  output = exp(output);
+
+  return output;
+}
+
+} // namespace ProposalExample
diff --git a/tools/clang/test/CodeGenDXIL/hlsl/linalg/outerproductaccumulate.hlsl b/tools/clang/test/CodeGenDXIL/hlsl/linalg/outerproductaccumulate.hlsl
new file mode 100644
index 0000000000..eda15c66f6
--- /dev/null
+++ b/tools/clang/test/CodeGenDXIL/hlsl/linalg/outerproductaccumulate.hlsl
@@ -0,0 +1,16 @@
+// RUN: %dxc -I %hlsl_headers -T lib_6_9 -enable-16bit-types %s | FileCheck %s
+
+#include <dx/linalg.h>
+
+RWByteAddressBuffer RWBuf;
+
+export void Test4(vector<half, 128> Input1, vector<half, 64> Input2) {
+  using namespace dx::linalg;
+
+  RWMatrixRef<DATA_TYPE_FLOAT16, 128, 64, MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL>
+      matrix = {RWBuf, 0, 0};
+
+  // CHECK: call void @dx.op.outerProductAccumulate.v128f16.v64f16(i32 307, <128 x half> %{{.+}}, <64 x half> %{{.+}}, %dx.types.Handle %{{.+}}, i32 0, i32 8, i32 3, i32 0)
+
+  OuterProductAccumulate(Input1, Input2, matrix);
+}
diff --git a/tools/clang/test/CodeGenDXIL/hlsl/linalg/vectoraccumulate.hlsl b/tools/clang/test/CodeGenDXIL/hlsl/linalg/vectoraccumulate.hlsl
new file mode 100644
index 0000000000..9157156f10
--- /dev/null
+++ b/tools/clang/test/CodeGenDXIL/hlsl/linalg/vectoraccumulate.hlsl
@@ -0,0 +1,14 @@
+// RUN: %dxc -I %hlsl_headers -T lib_6_9 %s | FileCheck %s
+
+#include <dx/linalg.h>
+
+RWByteAddressBuffer RWBuf;
+
+export void Test5(vector<float, 128> Input) {
+  using namespace dx::linalg;
+
+  RWBuf.Store<vector<float, 128> >(0, Input);
+
+  // CHECK: call void @dx.op.vectorAccumulate.v128f32(i32 308, <128 x float> %{{.*}}, %dx.types.Handle %{{.*}}, i32 0)
+  VectorAccumulate(Input, RWBuf, 0);
+}
diff --git a/tools/clang/test/SemaHLSL/hlsl/linalg/make-interp-vec-errors.hlsl b/tools/clang/test/SemaHLSL/hlsl/linalg/make-interp-vec-errors.hlsl
new file mode 100644
index 0000000000..9f2793d417
--- /dev/null
+++ b/tools/clang/test/SemaHLSL/hlsl/linalg/make-interp-vec-errors.hlsl
@@ -0,0 +1,33 @@
+// RUN: %dxc -I %hlsl_headers -T lib_6_9 %s -verify
+
+#include <dx/linalg.h>
+ByteAddressBuffer Buf;
+
+export float4 Test1(vector<float, 4> Input) {
+  using namespace dx::linalg;
+
+  MatrixRef<DATA_TYPE_FLOAT16, 4, 4, MATRIX_LAYOUT_MUL_OPTIMAL, true> Matrix = {
+      Buf, 0, 0};
+
+  // expected-error@+3{{no matching function for call to 'MakeInterpretedVector'}}
+  // expected-note@dx/linalg.h:97{{candidate template ignored: invalid explicitly-specified argument for template parameter 'DT'}}
+  return Mul<float>(
+      Matrix, MakeInterpretedVector<2>(Input));
+}
+
+enum DataType {
+  DATA_TYPE_InvalidType = 40
+};
+
+export float4 Test2(vector<float, 4> Input) {
+  using namespace dx::linalg;
+
+  MatrixRef<DATA_TYPE_FLOAT16, 4, 4, MATRIX_LAYOUT_MUL_OPTIMAL, true> Matrix = {
+      Buf, 0, 0};
+
+  // expected-error@+3{{no matching function for call to 'MakeInterpretedVector'}}
+  // expected-note@dx/linalg.h:97{{candidate template ignored: invalid explicitly-specified argument for template parameter 'DT'}}
+  return Mul<float>(
+      Matrix, MakeInterpretedVector<DATA_TYPE_InvalidType>(Input));
+}
+
diff --git a/tools/clang/test/SemaHLSL/hlsl/linalg/mat-vec-mul-errors.hlsl b/tools/clang/test/SemaHLSL/hlsl/linalg/mat-vec-mul-errors.hlsl
new file mode 100644
index 0000000000..2d5a11e83e
--- /dev/null
+++ b/tools/clang/test/SemaHLSL/hlsl/linalg/mat-vec-mul-errors.hlsl
@@ -0,0 +1,16 @@
+// RUN: %dxc -I %hlsl_headers -T lib_6_9 %s -verify
+
+#include <dx/linalg.h>
+
+ByteAddressBuffer Buf;
+
+vector<float, 4> MixUpVectorAndMatrixArguments(vector<float, 4> Input) {
+  using namespace dx::linalg;
+
+  MatrixRef<DATA_TYPE_FLOAT16, 4, 4, MATRIX_LAYOUT_MUL_OPTIMAL> Matrix = {
+      Buf, 0, 0};
+
+  // expected-error@+2{{no matching function for call to 'Mul'}}
+  // expected-note@dx/linalg.h:111{{candidate template ignored: could not match 'MatrixRefImpl' against 'InterpretedVector'}}
+  return Mul<float>(MakeInterpretedVector<DATA_TYPE_FLOAT16>(Input), Matrix);
+}
diff --git a/tools/clang/test/SemaHLSL/hlsl/linalg/mat-vec-mul-transpose-errors.hlsl b/tools/clang/test/SemaHLSL/hlsl/linalg/mat-vec-mul-transpose-errors.hlsl
new file mode 100644
index 0000000000..2018acafab
--- /dev/null
+++ b/tools/clang/test/SemaHLSL/hlsl/linalg/mat-vec-mul-transpose-errors.hlsl
@@ -0,0 +1,30 @@
+// XFAIL: *
+// RUN: %dxc -I %hlsl_headers -T lib_6_9 -enable-16bit-types %s -verify
+
+#include <dx/linalg.h>
+
+ByteAddressBuffer Buf;
+
+export float4 Test1(vector<float, 4> Input) {
+  using namespace dx::linalg;
+
+  MatrixRef<DATA_TYPE_FLOAT16, 4, 4, MATRIX_LAYOUT_ROW_MAJOR, true> Matrix = {
+      Buf, 0, 0};
+
+  // PREVIEW CHECK TODO:
+  // expected-error@+1{{something about transposing not supported for rowmajor / colmajor layouts}}
+  return Mul<float>(
+      Matrix, MakeInterpretedVector<DATA_TYPE_FLOAT16>(Input));
+}
+
+export vector<float, 8> Test2(vector<float, 6> Input) {
+  using namespace dx::linalg;
+
+  MatrixRef<DATA_TYPE_UINT8, 8, 24, MATRIX_LAYOUT_COLUMN_MAJOR, true> Matrix = {
+      Buf, 0, 0};
+
+  // PREVIEW CHECK TODO:
+  // expected-error@+1{{something about transposing not supported for rowmajor / colmajor layouts}}
+  return Mul<float>(Matrix,
+                    MakeInterpretedVector<DATA_TYPE_UINT8_T4_PACKED>(Input));
+}
diff --git a/tools/clang/test/SemaHLSL/hlsl/linalg/mat-vec-muladd-errors.hlsl b/tools/clang/test/SemaHLSL/hlsl/linalg/mat-vec-muladd-errors.hlsl
new file mode 100644
index 0000000000..f444f81c3a
--- /dev/null
+++ b/tools/clang/test/SemaHLSL/hlsl/linalg/mat-vec-muladd-errors.hlsl
@@ -0,0 +1,16 @@
+// RUN: %dxc -I %hlsl_headers -T lib_6_9 %s -verify
+
+#include <dx/linalg.h>
+
+ByteAddressBuffer Buf;
+
+vector<float, 4> MixUpVectorAndMatrixArguments(vector<float, 4> Input) {
+  using namespace dx::linalg;
+
+  MatrixRef<DATA_TYPE_FLOAT16, 4, 4, MATRIX_LAYOUT_MUL_OPTIMAL> Matrix = {
+      Buf, 0, 0};
+
+  // expected-error@+2{{no matching function for call to 'MulAdd'}}
+  // expected-note@dx/linalg.h:137{{candidate template ignored: could not match 'MatrixRefImpl' against 'InterpretedVector'}}
+  return MulAdd<float>(MakeInterpretedVector<DATA_TYPE_FLOAT16>(Input), Matrix,
+                       MakeInterpretedVector<DATA_TYPE_FLOAT16>(Input));
+}
diff --git a/tools/clang/test/SemaHLSL/hlsl/linalg/outerproductaccumulate-errors.hlsl b/tools/clang/test/SemaHLSL/hlsl/linalg/outerproductaccumulate-errors.hlsl
new file mode 100644
index 0000000000..6f503b367b
--- /dev/null
+++ b/tools/clang/test/SemaHLSL/hlsl/linalg/outerproductaccumulate-errors.hlsl
@@ -0,0 +1,44 @@
+// RUN: %dxc -I %hlsl_headers -T lib_6_9 -enable-16bit-types %s -verify
+
+#include <dx/linalg.h>
+
+RWByteAddressBuffer RWBuf;
+
+// test for inputs of different size
+export void Test4(vector<half, 128> Input1, vector<half, 64> Input2) {
+  using namespace dx::linalg;
+
+  RWMatrixRef<DATA_TYPE_FLOAT16, 128, 128, MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL>
+      matrix = {RWBuf, 0, 0};
+
+  // expected-error@+3{{no matching function for call to 'OuterProductAccumulate'}}
+  // expected-note@dx/linalg.h:161{{candidate template ignored: could not match 0 against 1}}
+
+  OuterProductAccumulate(Input1, Input2, matrix);
+}
+
+// now test for an error when element types differ
+export void Test5(vector<half, 128> Input1, vector<float, 64> Input2) {
+  using namespace dx::linalg;
+
+  RWMatrixRef<DATA_TYPE_FLOAT16, 128, 64, MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL>
+      matrix = {RWBuf, 0, 0};
+
+  // expected-error@+3{{no matching function for call to 'OuterProductAccumulate'}}
+  // expected-note@dx/linalg.h:161{{candidate template ignored: could not match 0 against 1}}
+
+  OuterProductAccumulate(Input1, Input2, matrix);
+}
+
+// now test for an error when matrix transpose parameter is true
+export void Test4(vector<int, 128> Input1, vector<uint, 64> Input2) {
+  using namespace dx::linalg;
+
+  RWMatrixRef<DATA_TYPE_SINT32, 128, 64, MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL,
+              true>
+      matrix = {RWBuf, 0, 0};
+
+  // expected-error@+3{{no matching function for call to 'OuterProductAccumulate'}}
+  // expected-note@dx/linalg.h:161{{candidate template ignored: deduced conflicting types for parameter 'ElTy' ('int' vs. 'unsigned int')}}
+
+  OuterProductAccumulate(Input1, Input2, matrix);
+}
diff --git a/tools/clang/test/SemaHLSL/hlsl/linalg/outerproductaccumulate-spirv-errors.hlsl b/tools/clang/test/SemaHLSL/hlsl/linalg/outerproductaccumulate-spirv-errors.hlsl
new file mode 100644
index 0000000000..0213103926
--- /dev/null
+++ b/tools/clang/test/SemaHLSL/hlsl/linalg/outerproductaccumulate-spirv-errors.hlsl
@@ -0,0 +1,19 @@
+// RUN: %dxc -I %hlsl_headers -T lib_6_9 -enable-16bit-types -spirv %s -verify
+
+// Tests that the header file cannot be included for spirv compilations
+// This is a copy of \tools\clang\test\CodeGenDXIL\hlsl\linalg\outerproductaccumulate.hlsl
+// except that spirv is targeted
+
+// expected-error@dx/linalg.h:4{{Cooperative vectors not (yet) supported for SPIRV}}
+#include <dx/linalg.h>
+
+RWByteAddressBuffer RWBuf;
+
+export void Test4(vector<half, 128> Input1, vector<half, 64> Input2) {
+  using namespace dx::linalg;
+
+  RWMatrixRef<DATA_TYPE_FLOAT16, 128, 64, MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL>
+      matrix = {RWBuf, 0, 0};
+
+  OuterProductAccumulate(Input1, Input2, matrix);
+}
diff --git a/tools/clang/test/SemaHLSL/hlsl/linalg/vectoraccumulate-errors.hlsl b/tools/clang/test/SemaHLSL/hlsl/linalg/vectoraccumulate-errors.hlsl
new file mode 100644
index 0000000000..4c8ae6f049
--- /dev/null
+++ b/tools/clang/test/SemaHLSL/hlsl/linalg/vectoraccumulate-errors.hlsl
@@ -0,0 +1,16 @@
+// XFAIL: *
+// RUN: %dxc -I %hlsl_headers -T lib_6_9 %s | FileCheck %s
+
+#include <dx/linalg.h>
+
+RWByteAddressBuffer RWBuf;
+
+export void Test5(vector<double, 128> Input) {
+  using namespace dx::linalg;
+
+  RWBuf.Store<vector<double, 128> >(0, Input);
+
+  // PREVIEW CHECK TODO:
+  // CHECK: Something about an error due to illegal conversions
+  VectorAccumulate(Input, RWBuf, 0);
+}
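
For reference, the sketch below shows how a shader outside this test suite might consume the new header. It is illustrative only and not part of the patch: the buffer names, register bindings, matrix dimensions, and byte offsets are assumptions, while the include path, enum values, and the Mul/MakeInterpretedVector signatures are the ones introduced above.

// Usage sketch (hypothetical; not part of this change). Assumes a 16x8 FP16
// weight matrix stored in a mul-optimal layout at offset 0 of `Weights`, and
// tightly packed float vectors in `Inputs`/`Outputs`.
#include <dx/linalg.h>

ByteAddressBuffer Weights : register(t0);
ByteAddressBuffer Inputs : register(t1);
RWByteAddressBuffer Outputs : register(u0);

[numthreads(64, 1, 1)]
void main(uint3 DTid : SV_DispatchThreadID) {
  using namespace dx::linalg;

  // Opaque, multiply-optimized 16x8 matrix of FP16 values; stride is ignored
  // for optimal layouts, so it is left as 0.
  MatrixRef<DATA_TYPE_FLOAT16, 16, 8, MATRIX_LAYOUT_MUL_OPTIMAL> Matrix = {
      Weights, 0, 0};

  // Each thread loads an 8-wide float vector (32 bytes) and asks for it to be
  // interpreted as FP16 during the matrix-vector multiply.
  vector<float, 8> Input = Inputs.Load<vector<float, 8> >(DTid.x * 32);
  vector<float, 16> Result =
      Mul<float>(Matrix, MakeInterpretedVector<DATA_TYPE_FLOAT16>(Input));

  // Store the 16-wide result (64 bytes per thread).
  Outputs.Store<vector<float, 16> >(DTid.x * 64, Result);
}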