Skip to content

[CoopVec] Add Linear Algebra common header with tests #7350

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Apr 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 182 additions & 0 deletions tools/clang/lib/Headers/hlsl/dx/linalg.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
// Header for linear algebra APIs.

#if __spirv__
#error "Cooperative vectors not (yet) supported for SPIRV"
#endif

#if ((__SHADER_TARGET_MAJOR > 6) || \
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we make this header error if it is included when targeting older shader models?
What about targeting SPIRV?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's fine to include the header when the shader is targeting older shader models. The user might have some preprocessor code to use these APIs if 6.9 is available, and to default to a fallback if not.
It is my understanding that whether we're targeting SPIRV or not doesn't matter, the builtins will be lowered to the correct IR regardless.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is only supported for d3d / dxil (which is why everything is / will be in the dx namespace)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The SPIRV path doesn't generate IR in DXC, it goes straight from the AST to SPIRV, which would likely cause the compiler either to crash or to emit a less-than helpful diagnostic.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we could add something like this to the top of the file?

#if __spirv__
#error "Cooperative matrix not (yet) supported for SPIRV"
#endif

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s/matrix/vector

(__SHADER_TARGET_MAJOR == 6 && __SHADER_TARGET_MINOR >= 9)) && \
(__HLSL_VERSION >= 2021)

namespace dx {
namespace linalg {

// NOTE: can't be an enum class because we get this error:
// error: non-type template argument of type 'dx::linalg::DataType' is not
// an integral constant expression
//
// Element-type tags describing how matrix/vector memory is interpreted.
// The numeric values mirror DXIL's ComponentType encoding (see the
// per-enumerator comments) and are passed straight through to the lowered
// DXIL ops, so they must not be renumbered.
enum DataType {
  DATA_TYPE_SINT16 = 2,           // ComponentType::I16
  DATA_TYPE_UINT16 = 3,           // ComponentType::U16
  DATA_TYPE_SINT32 = 4,           // ComponentType::I32
  DATA_TYPE_UINT32 = 5,           // ComponentType::U32
  DATA_TYPE_FLOAT16 = 8,          // ComponentType::F16
  DATA_TYPE_FLOAT32 = 9,          // ComponentType::F32
  DATA_TYPE_SINT8_T4_PACKED = 17, // ComponentType::PackedS8x32
  DATA_TYPE_UINT8_T4_PACKED = 18, // ComponentType::PackedU8x32
  DATA_TYPE_UINT8 = 19,           // ComponentType::U8
  DATA_TYPE_SINT8 = 20,           // ComponentType::I8
  DATA_TYPE_FLOAT8_E4M3 = 21,     // ComponentType::F8_E4M3
                                  // (1 sign, 4 exp, 3 mantissa bits)
  DATA_TYPE_FLOAT8_E5M2 = 22,     // ComponentType::F8_E5M2
                                  // (1 sign, 5 exp, 2 mantissa bits)
};

// Memory layouts accepted by the matrix-vector operations. The numeric
// values are forwarded verbatim as the layout operand of the lowered DXIL
// calls (see the Mul/MulAdd/OuterProductAccumulate wrappers below).
enum MatrixLayout {
  MATRIX_LAYOUT_ROW_MAJOR = 0,
  MATRIX_LAYOUT_COLUMN_MAJOR = 1,
  MATRIX_LAYOUT_MUL_OPTIMAL = 2,          // implementation-chosen layout for Mul/MulAdd
  MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL = 3 // implementation-chosen layout for OuterProductAccumulate
};

//
// Helper for signedness
//
namespace details {
// Returns true when T is an unsigned integer type, false otherwise
// (signed integers, floating point). Used by Mul/MulAdd below to compute
// the signed/unsigned interpretation flags passed to the builtins.
template <typename T> bool IsUnsigned() { return false; }

// uint16_t only exists as a native type when 16-bit types are enabled.
#ifdef __HLSL_ENABLE_16_BIT
template <> bool IsUnsigned<uint16_t>() { return true; }
#endif

template <> bool IsUnsigned<uint32_t>() { return true; }
template <> bool IsUnsigned<uint64_t>() { return true; }
} // namespace details

//
// (RW)MatrixRef
//

// Non-owning reference to an M x K matrix of DT-typed elements stored in a
// (RW)ByteAddressBuffer starting at byte offset StartOffset. Stride is the
// byte distance between rows/columns for the explicit (row/column-major)
// layouts; the accompanying tests pass 0 for the *_OPTIMAL layouts.
template <typename BufferTy, DataType DT, uint M, uint K, MatrixLayout ML,
          bool Transpose>
struct MatrixRefImpl {
  BufferTy Buffer;
  uint StartOffset;
  uint Stride;
};

// Read-only matrix reference.
template <DataType DT, uint M, uint K, MatrixLayout ML, bool Transpose = false>
using MatrixRef = MatrixRefImpl<ByteAddressBuffer, DT, M, K, ML, Transpose>;

// Read-write matrix reference (required by OuterProductAccumulate).
template <DataType DT, uint M, uint K, MatrixLayout ML, bool Transpose = false>
using RWMatrixRef = MatrixRefImpl<RWByteAddressBuffer, DT, M, K, ML, Transpose>;

//
// (RW)VectorRef
//

// Non-owning reference to a DT-typed vector stored in a
// (RW)ByteAddressBuffer at byte offset StartOffset (used, e.g., for the
// bias vector argument of MulAdd).
template <typename BufferTy, DataType DT> struct VectorRefImpl {
  BufferTy Buffer;
  uint StartOffset;
};

// Read-only vector reference.
template <DataType DT> using VectorRef = VectorRefImpl<ByteAddressBuffer, DT>;

// Read-write vector reference.
template <DataType DT>
using RWVectorRef = VectorRefImpl<RWByteAddressBuffer, DT>;

//
// Vector
//

// Pairs an in-register vector with the DataType it should be interpreted
// as by the matrix-vector ops. DT may differ from T (e.g. a float vector
// interpreted as DATA_TYPE_FLOAT16, as in the tests).
template <typename T, int N, DataType DT> struct InterpretedVector {
  vector<T, N> Data;
};

// Convenience factory: lets callers spell only the interpretation DT while
// T and N are deduced from the argument, e.g.
//   MakeInterpretedVector<DATA_TYPE_FLOAT16>(Vec)
template <DataType DT, typename T, int N>
InterpretedVector<T, N, DT> MakeInterpretedVector(vector<T, N> Vec) {
  InterpretedVector<T, N, DT> IV = {Vec};
  return IV;
}

//
// Mul
//

// Computes Matrix * InputVector, returning a vector of MatrixM elements of
// type OutputElTy. OutputElTy is the only template argument callers
// normally supply explicitly; everything else is deduced from the
// arguments. Lowered to the matVecMul DXIL op.
//
// NOTE: the layout parameter is named ML (not MatrixLayout) so it does not
// shadow the enum type MatrixLayout inside this template.
template <typename OutputElTy, typename InputElTy, int InputElCount,
          typename MatrixBufferTy, DataType InputDT, DataType MatrixDT,
          uint MatrixM, uint MatrixK, MatrixLayout ML, bool MatrixTranspose>
vector<OutputElTy, MatrixM>
Mul(MatrixRefImpl<MatrixBufferTy, MatrixDT, MatrixM, MatrixK, ML,
                  MatrixTranspose>
        Matrix,
    InterpretedVector<InputElTy, InputElCount, InputDT> InputVector) {

  vector<OutputElTy, MatrixM> OutputVector;

  __builtin_MatVecMul(
      /*out*/ OutputVector, details::IsUnsigned<OutputElTy>(), InputVector.Data,
      details::IsUnsigned<InputElTy>(), InputDT, Matrix.Buffer,
      Matrix.StartOffset, MatrixDT, MatrixM, MatrixK, ML, MatrixTranspose,
      Matrix.Stride);

  return OutputVector;
}

//
// MulAdd
//

// Computes Matrix * InputVector + BiasVector, returning a vector of
// MatrixM elements of type OutputElTy. OutputElTy is the only template
// argument callers normally supply explicitly. Lowered to the
// matVecMulAdd DXIL op.
//
// NOTE: the layout parameter is named ML (not MatrixLayout) so it does not
// shadow the enum type MatrixLayout inside this template.
template <typename OutputElTy, typename InputElTy, int InputElCount,
          typename MatrixBufferTy, DataType InputDT, DataType MatrixDT,
          uint MatrixM, uint MatrixK, MatrixLayout ML, bool MatrixTranspose,
          typename BiasVectorBufferTy, DataType BiasVectorDT>
vector<OutputElTy, MatrixM>
MulAdd(MatrixRefImpl<MatrixBufferTy, MatrixDT, MatrixM, MatrixK, ML,
                     MatrixTranspose>
           Matrix,
       InterpretedVector<InputElTy, InputElCount, InputDT> InputVector,
       VectorRefImpl<BiasVectorBufferTy, BiasVectorDT> BiasVector) {

  vector<OutputElTy, MatrixM> OutputVector;

  __builtin_MatVecMulAdd(
      /*out*/ OutputVector, details::IsUnsigned<OutputElTy>(), InputVector.Data,
      details::IsUnsigned<InputElTy>(), InputDT, Matrix.Buffer,
      Matrix.StartOffset, MatrixDT, MatrixM, MatrixK, ML, MatrixTranspose,
      Matrix.Stride, BiasVector.Buffer, BiasVector.StartOffset, BiasVectorDT);

  return OutputVector;
}

//
// OuterProductAccumulate
//

// Accumulates the MatrixM x MatrixN outer product of InputVector1 and
// InputVector2 into a read-write matrix. The matrix reference must not be
// declared transposed (the Transpose argument of RWMatrixRef is pinned to
// false). Lowered to the outerProductAccumulate DXIL op.
//
// NOTE: the layout parameter is named ML (not MatrixLayout) so it does not
// shadow the enum type MatrixLayout inside this template.
template <typename ElTy, int MatrixM, int MatrixN, DataType MatrixDT,
          MatrixLayout ML>
void OuterProductAccumulate(
    vector<ElTy, MatrixM> InputVector1, vector<ElTy, MatrixN> InputVector2,
    RWMatrixRef<MatrixDT, MatrixM, MatrixN, ML, false> Matrix) {
  __builtin_OuterProductAccumulate(InputVector1, InputVector2, Matrix.Buffer,
                                   Matrix.StartOffset, MatrixDT, ML,
                                   Matrix.Stride);
}

//
// VectorAccumulate
//

// Accumulates the ElCount elements of InputVector into Buffer starting at
// byte offset Offset. Lowered to the vectorAccumulate DXIL op.
template <typename ElTy, int ElCount>
void VectorAccumulate(vector<ElTy, ElCount> InputVector,
                      RWByteAddressBuffer Buffer, uint Offset) {
  __builtin_VectorAccumulate(InputVector, Buffer, Offset);
}

} // namespace linalg
} // namespace dx

#endif // SM 6.9 check and HV version check
40 changes: 40 additions & 0 deletions tools/clang/test/CodeGenDXIL/hlsl/linalg/mat-vec-mul.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// RUN: %dxc -I %hlsl_headers -T lib_6_9 -enable-16bit-types %s | FileCheck %s

#include <dx/linalg.h>

ByteAddressBuffer Buf;

// 4x4 FLOAT16 matrix (transposed, mul-optimal layout) times a float4
// interpreted as FLOAT16; the transpose flag shows up as `i1 true` below.
export float4 Test1(vector<float, 4> Input) {
  using namespace dx::linalg;

  MatrixRef<DATA_TYPE_FLOAT16, 4, 4, MATRIX_LAYOUT_MUL_OPTIMAL, true> Matrix = {
      Buf, 0, 0};

  // CHECK: %{{.+}} = call <4 x float> @dx.op.matVecMul.v4f32.v4f32(i32 305, <4 x float> %{{.+}}, i1 false, i32 8, %dx.types.Handle %{{.+}}, i32 0, i32 8, i32 4, i32 4, i32 2, i1 true, i32 0, i1 false)
  return Mul<float>(
      Matrix, MakeInterpretedVector<DATA_TYPE_FLOAT16>(Input));
}

// 8x24 UINT8 matrix times a packed-uint8 vector in the mul-optimal layout.
export vector<float, 8> Test2(vector<uint8_t4_packed, 6> Input) {
  using namespace dx::linalg;

  MatrixRef<DATA_TYPE_UINT8, 8, 6 * 4, MATRIX_LAYOUT_MUL_OPTIMAL> Matrix = {
      Buf, 0, 0};

  // note the stride argument is dropped.
  // CHECK: %{{.+}} = call <8 x float> @dx.op.matVecMul.v8f32.v6f32(i32 305, <6 x float> %{{.+}}, i1 false, i32 18, %dx.types.Handle %{{.+}}, i32 0, i32 19, i32 8, i32 24, i32 2, i1 false, i32 0, i1 false)
  return Mul<float>(Matrix,
                    MakeInterpretedVector<DATA_TYPE_UINT8_T4_PACKED>(Input));
}

// test that "stride" isn't ignored in non-optimal layouts: with a
// row-major layout the explicit byte stride (6 * 4 * 8 = 192) must appear
// as the stride operand of the DXIL call.
export vector<float, 8> Test3(vector<uint8_t4_packed, 6> Input) {
  using namespace dx::linalg;

  MatrixRef<DATA_TYPE_UINT8, 8, 6 * 4, MATRIX_LAYOUT_ROW_MAJOR> Matrix = {
      Buf, 0, 6 * 4 * 8};

  // CHECK: %{{.+}} = call <8 x float> @dx.op.matVecMul.v8f32.v6f32(i32 305, <6 x float> %{{.+}}, i1 false, i32 18, %dx.types.Handle %{{.+}}, i32 0, i32 19, i32 8, i32 24, i32 0, i1 false, i32 192, i1 false)
  return Mul<float>(Matrix,
                    MakeInterpretedVector<DATA_TYPE_UINT8_T4_PACKED>(Input));
}
90 changes: 90 additions & 0 deletions tools/clang/test/CodeGenDXIL/hlsl/linalg/mat-vec-muladd.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// RUN: %dxc -I %hlsl_headers -T lib_6_9 %s | FileCheck %s

#include <dx/linalg.h>

ByteAddressBuffer Buf;

// MulAdd with an explicitly brace-initialized InterpretedVector (no
// MakeInterpretedVector helper).
export float4 Test1(float4 input) {
  using namespace dx::linalg;

  MatrixRef<DATA_TYPE_FLOAT16, 4, 4, MATRIX_LAYOUT_MUL_OPTIMAL> matrix = {Buf,
                                                                          0, 0};
  VectorRef<DATA_TYPE_FLOAT16> biasVector = {Buf, 256};

  InterpretedVector<float, 4, DATA_TYPE_FLOAT16> theVector = {input};

  // CHECK: %{{.+}} = call <4 x float> @dx.op.matVecMulAdd.v4f32.v4f32(i32 306, <4 x float> %{{.+}}, i1 false, i32 8, %dx.types.Handle [[RES:%.+]], i32 0, i32 8, i32 4, i32 4, i32 2, i1 false, i32 0, %dx.types.Handle [[RES]], i32 256, i32 8, i1 false)
  return MulAdd<float>(
      matrix, theVector,
      biasVector);
}

// Same as Test1 but with the matrix marked transposed (the `i1 true`
// operand in the CHECK line below).
export float4 Test2(float4 input) {
  using namespace dx::linalg;

  MatrixRef<DATA_TYPE_FLOAT16, 4, 4, MATRIX_LAYOUT_MUL_OPTIMAL, true> matrix = {
      Buf, 0, 0};
  VectorRef<DATA_TYPE_FLOAT16> biasVector = {Buf, 256};

  InterpretedVector<float, 4, DATA_TYPE_FLOAT16> theVector = {input};

  // CHECK: %{{.+}} = call <4 x float> @dx.op.matVecMulAdd.v4f32.v4f32(i32 306, <4 x float> %{{.+}}, i1 false, i32 8, %dx.types.Handle [[RES:%.+]], i32 0, i32 8, i32 4, i32 4, i32 2, i1 true, i32 0, %dx.types.Handle [[RES]], i32 256, i32 8, i1 false)
  return MulAdd<float>(
      matrix, theVector,
      biasVector);
}

// Same as Test2 but building the interpreted vector through the
// MakeInterpretedVector helper; the lowered call must be identical.
export float4 Test3(float4 input) {
  using namespace dx::linalg;

  MatrixRef<DATA_TYPE_FLOAT16, 4, 4, MATRIX_LAYOUT_MUL_OPTIMAL, true> matrix = {
      Buf, 0, 0};
  VectorRef<DATA_TYPE_FLOAT16> biasVector = {Buf, 256};

  // CHECK: %{{.+}} = call <4 x float> @dx.op.matVecMulAdd.v4f32.v4f32(i32 306, <4 x float> %{{.+}}, i1 false, i32 8, %dx.types.Handle [[RES:%.+]], i32 0, i32 8, i32 4, i32 4, i32 2, i1 true, i32 0, %dx.types.Handle [[RES]], i32 256, i32 8, i1 false)
  return MulAdd<float>(
      matrix, MakeInterpretedVector<DATA_TYPE_FLOAT16>(input),
      biasVector);
}

namespace ProposalExample {

// Small MLP example: two hidden layers with ReLU activations and an
// exponential on the 3-component output. There are no CHECK lines here;
// this exercises compilation of chained MulAdd calls with FLOAT8_E4M3
// weights and FLOAT16 biases packed into one buffer at fixed byte offsets.
ByteAddressBuffer model;

vector<float, 3> ApplyNeuralMaterial(vector<half, 8> inputVector) {
  using namespace dx::linalg;

  // Per-layer weight matrices and bias vectors inside the model buffer.
  MatrixRef<DATA_TYPE_FLOAT8_E4M3, 32, 8, MATRIX_LAYOUT_MUL_OPTIMAL> matrix0 = {
      model, 0, 0};

  VectorRef<DATA_TYPE_FLOAT16> biasVector0 = {model, 1024};

  MatrixRef<DATA_TYPE_FLOAT8_E4M3, 32, 32, MATRIX_LAYOUT_MUL_OPTIMAL> matrix1 =
      {model, 2048, 0};

  VectorRef<DATA_TYPE_FLOAT16> biasVector1 = {model, 3072};

  MatrixRef<DATA_TYPE_FLOAT8_E4M3, 3, 32, MATRIX_LAYOUT_MUL_OPTIMAL> matrix2 = {
      model, 4096, 0};

  VectorRef<DATA_TYPE_FLOAT16> biasVector2 = {model, 5120};

  // layer0 = relu(matrix0 * input + bias0)
  vector<half, 32> layer0 = MulAdd<half>(
      matrix0, MakeInterpretedVector<DATA_TYPE_FLOAT8_E4M3>(inputVector),
      biasVector0);
  layer0 = max(layer0, 0);

  // layer1 = relu(matrix1 * layer0 + bias1)
  vector<half, 32> layer1 = MulAdd<half>(
      matrix1, MakeInterpretedVector<DATA_TYPE_FLOAT8_E4M3>(layer0),
      biasVector1);
  layer1 = max(layer1, 0);

  // output = exp(matrix2 * layer1 + bias2)
  vector<float, 3> output = MulAdd<float>(
      matrix2, MakeInterpretedVector<DATA_TYPE_FLOAT8_E4M3>(layer1),
      biasVector2);
  output = exp(output);

  return output;
}

} // namespace ProposalExample
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// RUN: %dxc -I %hlsl_headers -T lib_6_9 -enable-16bit-types %s | FileCheck %s

#include <dx/linalg.h>

RWByteAddressBuffer RWBuf;

// 128x64 FLOAT16 accumulation matrix in the outer-product-optimal layout;
// the stride is 0 (the final `i32 0` operand in the CHECK line).
export void Test4(vector<half, 128> Input1, vector<half, 64> Input2) {
  using namespace dx::linalg;

  RWMatrixRef<DATA_TYPE_FLOAT16, 128, 64, MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL>
      matrix = {RWBuf, 0, 0};

  // CHECK: call void @dx.op.outerProductAccumulate.v128f16.v64f16(i32 307, <128 x half> %{{.+}}, <64 x half> %{{.+}}, %dx.types.Handle %{{.+}}, i32 0, i32 8, i32 3, i32 0)

  OuterProductAccumulate(Input1, Input2, matrix);
}
14 changes: 14 additions & 0 deletions tools/clang/test/CodeGenDXIL/hlsl/linalg/vectoraccumulate.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: %dxc -I %hlsl_headers -T lib_6_9 %s | FileCheck %s

#include <dx/linalg.h>

RWByteAddressBuffer RWBuf;

// NOTE(review): the CHECK expects <128 x float> (v128f32) although Input
// is vector<half, 128> — presumably because this RUN line has no
// -enable-16bit-types so half lowers to float; confirm.
export void Test5(vector<half, 128> Input) {
  using namespace dx::linalg;

  // Store the input into the destination buffer first.
  RWBuf.Store<vector<half, 128> >(0, Input);

  // CHECK: call void @dx.op.vectorAccumulate.v128f32(i32 308, <128 x float> %{{.*}}, %dx.types.Handle %{{.*}}, i32 0)
  VectorAccumulate(Input, RWBuf, 0);
}
33 changes: 33 additions & 0 deletions tools/clang/test/SemaHLSL/hlsl/linalg/make-interp-vec-errors.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// RUN: %dxc -I %hlsl_headers -T lib_6_9 %s -verify

#include <dx/linalg.h>
ByteAddressBuffer Buf;

// A plain integer (2) is not a valid explicit argument for the DataType
// template parameter of MakeInterpretedVector and must be rejected.
// (Comments are kept above the expected-error line so the @+3 offset to
// the erroring statement is preserved.)
export float4 Test1(vector<float, 4> Input) {
  using namespace dx::linalg;

  MatrixRef<DATA_TYPE_UINT16, 4, 4, MATRIX_LAYOUT_MUL_OPTIMAL, true> Matrix = {
      Buf, 0, 0};

  // expected-error@+3{{no matching function for call to 'MakeInterpretedVector'}}
  // expected-note@dx/linalg.h:97{{candidate template ignored: invalid explicitly-specified argument for template parameter 'DT'}}
  return Mul<float>(
      Matrix, MakeInterpretedVector<2>(Input));
}

// A look-alike DataType enum in the global namespace; its enumerators must
// not be accepted where dx::linalg::DataType is required (see Test2).
enum DataType {
  DATA_TYPE_InvalidType = 40
};

// An enumerator of the unrelated global ::DataType enum above must be
// rejected just like a plain integer.
// (Comments are kept above the expected-error line so the @+3 offset to
// the erroring statement is preserved.)
export float4 Test2(vector<float, 4> Input) {
  using namespace dx::linalg;

  MatrixRef<DATA_TYPE_UINT16, 4, 4, MATRIX_LAYOUT_MUL_OPTIMAL, true> Matrix = {
      Buf, 0, 0};

  // expected-error@+3{{no matching function for call to 'MakeInterpretedVector'}}
  // expected-note@dx/linalg.h:97{{candidate template ignored: invalid explicitly-specified argument for template parameter 'DT'}}
  return Mul<float>(
      Matrix, MakeInterpretedVector<DATA_TYPE_InvalidType>(Input));
}

16 changes: 16 additions & 0 deletions tools/clang/test/SemaHLSL/hlsl/linalg/mat-vec-mul-errors.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// RUN: %dxc -I %hlsl_headers -T lib_6_9 %s -verify

#include <dx/linalg.h>

ByteAddressBuffer Buf;

// Passing the interpreted vector where the matrix is expected (and vice
// versa) must fail template argument deduction on Mul.
// (Comments are kept above the expected-error line so the @+2 offset to
// the erroring statement is preserved.)
vector<float, 128> MixUpVectorAndMatrixArguments(vector<float, 128> Input) {
  using namespace dx::linalg;

  MatrixRef<DATA_TYPE_FLOAT16, 128, 128, MATRIX_LAYOUT_MUL_OPTIMAL> Matrix = {
      Buf, 0, 0};

  // expected-error@+2{{no matching function for call to 'Mul'}}
  // expected-note@dx/linalg.h:111{{candidate template ignored: could not match 'MatrixRefImpl' against 'InterpretedVector'}}
  return Mul<float>(MakeInterpretedVector<DATA_TYPE_FLOAT16>(Input), Matrix);
}
Loading