Skip to content

Commit b95cf18

Browse files
metal lowbit kernels: qmv_fast optimization (#2167)
1 parent 2c901b3 commit b95cf18

19 files changed

+569
-188
lines changed

torchao/experimental/kernels/mps/metal.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,6 @@
2121

2222
- func: int7mm
2323
file: int7mm.metal
24+
25+
- func: qmv_fast
26+
file: qmv_fast.metal

torchao/experimental/kernels/mps/metal/int1mm.metal

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ using namespace metal;
1111
*
1212
* @param[A] M x K input tensor of floating point dtype (Float, Half, BFloat16)
1313
* @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (K / 8)
14-
* @param[scales] 2D tensor containg the scales for each group. Expected shape is #groups x N
15-
* @param[zeros] 2D tensor containg the zero points for each group. Expected shape is #groups x N
14+
* @param[scales] 2D tensor containing the scales for each group. Expected shape is N x #groups
15+
* @param[zeros] 2D tensor containing the zero points for each group. Expected shape is N x #groups
1616
* @param[outputData] M x N output tensor of floating point dtype (same as input)
1717
* @param[sizes] The sizes involved in the order: M, K, N
1818
*
@@ -29,6 +29,7 @@ kernel void int1pack_mm(
2929
uint2 thread_index [[thread_position_in_grid]]) {
3030
const uint K = sizes.y;
3131
const uint N = sizes.z;
32+
const uint num_groups = (K + groupSize - 1) / groupSize;
3233
const uint m = thread_index.y; // 0..M-1
3334
const uint n = thread_index.x; // 0..N-1
3435
const uint32_t k_block = (K + groupSize - 1) / groupSize;
@@ -38,8 +39,8 @@ kernel void int1pack_mm(
3839
float rc = 0.0;
3940
uint k = 0;
4041
for (uint32_t kb = 0; kb < k_block ; kb ++) {
41-
const float scale = float(scales[kb * N + n]);
42-
const float zero = float(zeros[kb * N + n]);
42+
const float scale = float(scales[n * num_groups + kb]);
43+
const float zero = float(zeros[n * num_groups + kb]);
4344
for(uint idx = 0; idx < groupSize && k < K; idx+=8, k+=8) {
4445
const auto a_val0 = float(A_ptr[k + 0]);
4546
const auto a_val1 = float(A_ptr[k + 1]);

torchao/experimental/kernels/mps/metal/int2mm_opt.metal

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,11 @@ using namespace metal;
2626
@param [in] B is weight matrix of size M x K. Each byte contains 4 2-bit
2727
values, along K dim, packed together.
2828
@param [in] scales_ptr is scales ptr corresponding each
29-
output channel x groups. These are packed as [num_groups = ceil(K / group_size), N]. N = output
29+
output channel x groups. These are packed as [N, num_groups = ceil(K / group_size)]. N = output
3030
channels.
3131
@param [in] zeros_ptr is zero points corresponding each
32-
output channel x groups. These are packed as [num_groups = ceil(K / group_size), N]. N = output
32+
output channel x groups. These are packed as [N, num_groups = ceil(K / group_size)]. N = output
3333
channels.
34-
output channel x groups. These are packed as [num_groups = ceil(K / group_size), N, 2]. N = output
3534
@param [out] output_data is output matrix of size M x N.
3635
@param [in] sizes array contains values of M, K and N.
3736
@param [in] thread_index is global thread id.
@@ -51,6 +50,7 @@ kernel void int2pack_mm(constant T *A [[buffer(0)]],
5150
constexpr uint k_pack_factor = 4;
5251
const uint K = sizes.y;
5352
const uint N = sizes.z;
53+
const uint num_groups = (K + group_size - 1) / group_size;
5454
uint n = thread_index.x; // 0..N/4-1
5555
uint m = thread_index.z; // 0..M
5656
n = n / threads_per_channel;
@@ -75,13 +75,18 @@ kernel void int2pack_mm(constant T *A [[buffer(0)]],
7575
// Find specific group to which channels handled by this thread
7676
// belong.
7777
uint k_block_index = k / group_size;
78-
uint scales_group_offset = (k_block_index * N + n);
78+
uint scales_group_offset = (n * num_groups + k_block_index);
7979

8080
vecT scales =
81-
(reinterpret_cast<constant vecT *>(scales_ptr + scales_group_offset))[0];
82-
// Adding zero point results in 10% perf penalty.
81+
vecT(scales_ptr[scales_group_offset],
82+
scales_ptr[scales_group_offset + num_groups],
83+
scales_ptr[scales_group_offset + 2 * num_groups],
84+
scales_ptr[scales_group_offset + 3 * num_groups]);
8385
vecT zeros =
84-
(reinterpret_cast<constant vecT *>(zeros_ptr + scales_group_offset))[0];
86+
vecT(zeros_ptr[scales_group_offset],
87+
zeros_ptr[scales_group_offset + num_groups],
88+
zeros_ptr[scales_group_offset + 2 * num_groups],
89+
zeros_ptr[scales_group_offset + 3 * num_groups]);
8590
float4 zeros_float = float4(zeros);
8691

8792
float4 a_val = float4(A_ptr[k / 4]);

torchao/experimental/kernels/mps/metal/int3mm_opt.metal

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,24 +8,23 @@
88
using namespace metal;
99

1010
inline void unpack_3bit(const uchar3 b, thread float* w) {
11-
w[0] = float(((b[0] & 1) << 2) | (b[1] & 3));
12-
w[1] = float(((b[0] & 2) << 1) | ((b[1] & 12) >> 2));
13-
w[2] = float((b[0] & 4) | ((b[1] & 48) >> 4));
14-
w[3] = float(((b[0] & 8) >> 1) | ((b[1] & 192) >> 6));
15-
16-
w[4] = float(((b[0] & 16) >> 2) | (b[2] & 3));
17-
w[5] = float(((b[0] & 32) >> 3) | ((b[2] & 12) >> 2));
18-
w[6] = float(((b[0] & 64) >> 4) | ((b[2] & 48) >> 4));
19-
w[7] = float(((b[0] & 128) >> 5) | ((b[2] & 192) >> 6));
11+
w[0] = float(b[0] & 0x07);
12+
w[1] = float((b[0] & 0x38) >> 3);
13+
w[2] = float(((b[0] & 0xc0) >> 6) | ((b[1] & 0x01) << 2));
14+
w[3] = float((b[1] & 0x0e) >> 1);
15+
w[4] = float((b[1] & 0x70) >> 4);
16+
w[5] = float(((b[1] & 0x80) >> 7) | ((b[2] & 0x03) << 1));
17+
w[6] = float((b[2] & 0x1c) >> 2);
18+
w[7] = float((b[2] & 0xe0) >> 5);
2019
}
2120

2221
/**
2322
* 3-Bit Quantized Linear.
2423
*
2524
* @param[A] M x K input tensor of floating point dtype (Float, Half, BFloat16)
2625
* @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (3 * K / 8)
27-
* @param[scales] 2D tensor containg the scales for each group. Expected shape is #groups x N
28-
* @param[zeros] 2D tensor containg the zero points for each group. Expected shape is #groups x N
26+
* @param[scales] 2D tensor containing the scales for each group. Expected shape is N x #groups
27+
* @param[zeros] 2D tensor containing the zero points for each group. Expected shape is N x #groups
2928
* @param[outputData] M x N output tensor of floating point dtype (same as input)
3029
* @param[sizes] The sizes involved in the order: M, K, N
3130
*
@@ -45,6 +44,7 @@ kernel void int3pack_mm(constant T *A [[buffer(0)]],
4544
constexpr uint k_pack_factor = 8;
4645
const uint K = sizes.y;
4746
const uint N = sizes.z;
47+
const uint num_groups = (K + group_size - 1) / group_size;
4848
uint n = thread_index.x; // 0..N/4-1
4949
uint m = thread_index.z; // 0..M
5050
n = n / threads_per_channel;
@@ -64,12 +64,18 @@ kernel void int3pack_mm(constant T *A [[buffer(0)]],
6464
// Find specific group to which channels handled by this thread
6565
// belong.
6666
uint k_block_index = k / group_size;
67-
uint scales_group_offset = (k_block_index * N + n);
67+
uint scales_group_offset = (n * num_groups + k_block_index);
6868

6969
vecT scales =
70-
(reinterpret_cast<constant vecT *>(scales_ptr + scales_group_offset))[0];
70+
vecT(scales_ptr[scales_group_offset],
71+
scales_ptr[scales_group_offset + num_groups],
72+
scales_ptr[scales_group_offset + 2 * num_groups],
73+
scales_ptr[scales_group_offset + 3 * num_groups]);
7174
vecT zeros =
72-
(reinterpret_cast<constant vecT *>(zeros_ptr + scales_group_offset))[0];
75+
vecT(zeros_ptr[scales_group_offset],
76+
zeros_ptr[scales_group_offset + num_groups],
77+
zeros_ptr[scales_group_offset + 2 * num_groups],
78+
zeros_ptr[scales_group_offset + 3 * num_groups]);
7379
float4 zeros_float = float4(zeros);
7480

7581
float4 a_val[2];

torchao/experimental/kernels/mps/metal/int4mm_opt.metal

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,11 @@ using namespace metal;
6464
@param [in] B is weight matrix of size M x K. Each byte contains 2 4-bit
6565
values, along K dim, packed together.
6666
@param [in] scales_ptr is scales ptr corresponding each
67-
output channel x groups. These are packed as [num_groups = ceil(K / group_size), N]. N = output
67+
output channel x groups. These are packed as [N, num_groups = ceil(K / group_size)]. N = output
6868
channels.
6969
@param [in] zeros_ptr is zero points corresponding each
70-
output channel x groups. These are packed as [num_groups = ceil(K / group_size), N]. N = output
70+
output channel x groups. These are packed as [N, num_groups = ceil(K / group_size)]. N = output
7171
channels.
72-
output channel x groups. These are packed as [num_groups = ceil(K / group_size), N, 2]. N = output
7372
@param [out] output_data is output matrix of size M x N.
7473
@param [in] sizes array contains values of M, K and N.
7574
@param [in] thread_index is global thread id.
@@ -89,6 +88,7 @@ kernel void int4pack_mm(constant T *A [[buffer(0)]],
8988
constexpr uint k_pack_factor = 2;
9089
const uint K = sizes.y;
9190
const uint N = sizes.z;
91+
const uint num_groups = (K + group_size - 1) / group_size;
9292
uint n = thread_index.x; // 0..N/4-1
9393
uint m = thread_index.z; // 0..M
9494
n = n / threads_per_channel;
@@ -113,13 +113,19 @@ kernel void int4pack_mm(constant T *A [[buffer(0)]],
113113
// Find specific group to which channels handled by this thread
114114
// belong.
115115
uint k_block_index = k / group_size;
116-
uint scales_group_offset = (k_block_index * N + n);
116+
uint scales_group_offset = (n * num_groups + k_block_index);
117117

118118
vecT scales =
119-
(reinterpret_cast<constant vecT *>(scales_ptr + scales_group_offset))[0];
119+
vecT(scales_ptr[scales_group_offset],
120+
scales_ptr[scales_group_offset + num_groups],
121+
scales_ptr[scales_group_offset + 2 * num_groups],
122+
scales_ptr[scales_group_offset + 3 * num_groups]);
120123
// Adding zero point results in 10% perf penalty.
121124
vecT zeros =
122-
(reinterpret_cast<constant vecT *>(zeros_ptr + scales_group_offset))[0];
125+
vecT(zeros_ptr[scales_group_offset],
126+
zeros_ptr[scales_group_offset + num_groups],
127+
zeros_ptr[scales_group_offset + 2 * num_groups],
128+
zeros_ptr[scales_group_offset + 3 * num_groups]);
123129
float4 zeros_float = float4(zeros);
124130

125131
float4 a_val = float4(A_ptr[k / 4]);

torchao/experimental/kernels/mps/metal/int5mm.metal

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ using namespace metal;
1111
*
1212
* @param[A] M x K input tensor of floating point dtype (Float, Half, BFloat16)
1313
* @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (5 * K / 8)
14-
* @param[scales] 2D tensor containg the scales for each group. Expected shape is #groups x N
15-
* @param[zeros] 2D tensor containg the zero points for each group. Expected shape is #groups x N
14+
* @param[scales] 2D tensor containing the scales for each group. Expected shape is N x #groups
15+
* @param[zeros] 2D tensor containing the zero points for each group. Expected shape is N x #groups
1616
* @param[outputData] M x N output tensor of floating point dtype (same as input)
1717
* @param[sizes] The sizes involved in the order: M, K, N
1818
*
@@ -29,6 +29,7 @@ kernel void int5pack_mm(
2929
uint2 thread_index [[thread_position_in_grid]]) {
3030
const uint K = sizes.y;
3131
const uint N = sizes.z;
32+
const uint num_groups = (K + groupSize - 1) / groupSize;
3233
const uint m = thread_index.y; // 0..M-1
3334
const uint n = thread_index.x; // 0..N-1
3435
const uint32_t k_block = (K + groupSize - 1) / groupSize;
@@ -38,8 +39,8 @@ kernel void int5pack_mm(
3839
float rc = 0.0;
3940
uint k = 0;
4041
for (uint32_t kb = 0; kb < k_block ; kb ++) {
41-
const float scale = float(scales[kb * N + n]);
42-
const float zero = float(zeros[kb * N + n]);
42+
const float scale = float(scales[n * num_groups + kb]);
43+
const float zero = float(zeros[n * num_groups + kb]);
4344
for(uint idx = 0; idx < groupSize && k < K; idx+=8, k+=8) {
4445
const auto a_val0 = float(A_ptr[k + 0]);
4546
const auto a_val1 = float(A_ptr[k + 1]);
@@ -56,15 +57,14 @@ kernel void int5pack_mm(
5657
uchar b3 = B_ptr[5 * (k / 8) + 3];
5758
uchar b4 = B_ptr[5 * (k / 8) + 4];
5859

59-
uchar w_val0 = ((b0 & 1) << 4) | (b1 & 15);
60-
uchar w_val1 = ((b0 & 2) << 3) | ((b1 & 240) >> 4);
61-
uchar w_val2 = ((b0 & 4) << 2) | (b2 & 15);
62-
uchar w_val3 = ((b0 & 8) << 1) | ((b2 & 240) >> 4);
63-
64-
uchar w_val4 = ((b0 & 16)) | (b3 & 15);
65-
uchar w_val5 = ((b0 & 32) >> 1) | ((b3 & 240) >> 4);
66-
uchar w_val6 = ((b0 & 64) >> 2) | (b4 & 15);
67-
uchar w_val7 = ((b0 & 128) >> 3) | ((b4 & 240) >> 4);
60+
uchar w_val0 = (b0 & 0x1f);
61+
uchar w_val1 = ((b0 & 0xe0) >> 5) | ((b1 & 0x03) << 3);
62+
uchar w_val2 = ((b1 & 0x7c) >> 2);
63+
uchar w_val3 = ((b1 & 0x80) >> 7) | ((b2 & 0x0f) << 1);
64+
uchar w_val4 = ((b2 & 0xf0) >> 4) | ((b3 & 0x01) << 4);
65+
uchar w_val5 = ((b3 & 0x3e) >> 1);
66+
uchar w_val6 = ((b3 & 0xc0) >> 6) | ((b4 & 0x07) << 2);
67+
uchar w_val7 = ((b4 & 0xf8) >> 3);
6868

6969
rc += a_val0 * (scale * float(w_val0) + zero);
7070
rc += a_val1 * (scale * float(w_val1) + zero);

torchao/experimental/kernels/mps/metal/int6mm.metal

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ using namespace metal;
1111
*
1212
* @param[A] M x K input tensor of floating point dtype (Float, Half, BFloat16)
1313
* @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (6 * K / 8)
14-
* @param[scales] 2D tensor containg the scales for each group. Expected shape is #groups x N
15-
* @param[zeros] 2D tensor containg the zero points for each group. Expected shape is #groups x N
14+
* @param[scales] 2D tensor containing the scales for each group. Expected shape is N x #groups
15+
* @param[zeros] 2D tensor containing the zero points for each group. Expected shape is N x #groups
1616
* @param[outputData] M x N output tensor of floating point dtype (same as input)
1717
* @param[sizes] The sizes involved in the order: M, K, N
1818
*
@@ -29,6 +29,7 @@ kernel void int6pack_mm(
2929
uint2 thread_index [[thread_position_in_grid]]) {
3030
const uint K = sizes.y;
3131
const uint N = sizes.z;
32+
const uint num_groups = (K + groupSize - 1) / groupSize;
3233
const uint m = thread_index.y; // 0..M-1
3334
const uint n = thread_index.x; // 0..N-1
3435
const uint32_t k_block = (K + groupSize - 1) / groupSize;
@@ -38,8 +39,8 @@ kernel void int6pack_mm(
3839
float rc = 0.0;
3940
uint k = 0;
4041
for (uint32_t kb = 0; kb < k_block ; kb ++) {
41-
const float scale = float(scales[kb * N + n]);
42-
const float zero = float(zeros[kb * N + n]);
42+
const float scale = float(scales[n * num_groups + kb]);
43+
const float zero = float(zeros[n * num_groups + kb]);
4344
for(uint idx = 0; idx < groupSize && k < K; idx+=8, k+=8) {
4445
const auto a_val0 = float(A_ptr[k + 0]);
4546
const auto a_val1 = float(A_ptr[k + 1]);
@@ -59,15 +60,15 @@ kernel void int6pack_mm(
5960
uchar b4 = B_ptr[3 * (k / 4) + 4];
6061
uchar b5 = B_ptr[3 * (k / 4) + 5];
6162

62-
uchar w_val0 = ((b0 & 3) << 4) | (b1 & 15);
63-
uchar w_val1 = ((b0 & 12) << 2) | ((b1 & 240) >> 4);
64-
uchar w_val2 = ((b0 & 48)) | (b2 & 15);
65-
uchar w_val3 = ((b0 & 192) >> 2) | ((b2 & 240) >> 4);
63+
uchar w_val0 = (b0 & 0x3f);
64+
uchar w_val1 = ((b0 & 0xc0) >> 6) | ((b1 & 0x0f) << 2);
65+
uchar w_val2 = ((b1 & 0xf0) >> 4) | ((b2 & 0x03) << 4);
66+
uchar w_val3 = (b2 & 0xfc) >> 2;
6667

67-
uchar w_val4 = ((b3 & 3) << 4) | (b4 & 15);
68-
uchar w_val5 = ((b3 & 12) << 2) | ((b4 & 240) >> 4);
69-
uchar w_val6 = ((b3 & 48)) | (b5 & 15);
70-
uchar w_val7 = ((b3 & 192) >> 2) | ((b5 & 240) >> 4);
68+
uchar w_val4 = (b3 & 0x3f);
69+
uchar w_val5 = ((b3 & 0xc0) >> 6) | ((b4 & 0x0f) << 2);
70+
uchar w_val6 = ((b4 & 0xf0) >> 4) | ((b5 & 0x03) << 4);
71+
uchar w_val7 = (b5 & 0xfc) >> 2;
7172

7273
rc += a_val0 * (scale * float(w_val0) + zero);
7374
rc += a_val1 * (scale * float(w_val1) + zero);

torchao/experimental/kernels/mps/metal/int7mm.metal

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ using namespace metal;
1111
*
1212
* @param[A] M x K input tensor of floating point dtype (Float, Half, BFloat16)
1313
* @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (7 * K / 8)
14-
* @param[scales] 2D tensor containg the scales for each group. Expected shape is #groups x N
15-
* @param[zeros] 2D tensor containg the zero points for each group. Expected shape is #groups x N
14+
* @param[scales] 2D tensor containing the scales for each group. Expected shape is N x #groups
15+
* @param[zeros] 2D tensor containing the zero points for each group. Expected shape is N x #groups
1616
* @param[outputData] M x N output tensor of floating point dtype (same as input)
1717
* @param[sizes] The sizes involved in the order: M, K, N
1818
*
@@ -29,6 +29,7 @@ kernel void int7pack_mm(
2929
uint2 thread_index [[thread_position_in_grid]]) {
3030
const uint K = sizes.y;
3131
const uint N = sizes.z;
32+
const uint num_groups = (K + groupSize - 1) / groupSize;
3233
const uint m = thread_index.y; // 0..M-1
3334
const uint n = thread_index.x; // 0..N-1
3435
const uint32_t k_block = (K + groupSize - 1) / groupSize;
@@ -38,8 +39,8 @@ kernel void int7pack_mm(
3839
float rc = 0.0;
3940
uint k = 0;
4041
for (uint32_t kb = 0; kb < k_block ; kb ++) {
41-
const float scale = float(scales[kb * N + n]);
42-
const float zero = float(zeros[kb * N + n]);
42+
const float scale = float(scales[n * num_groups + kb]);
43+
const float zero = float(zeros[n * num_groups + kb]);
4344
for(uint idx = 0; idx < groupSize && k < K; idx+=8, k+=8) {
4445
const auto a_val0 = float(A_ptr[k + 0]);
4546
const auto a_val1 = float(A_ptr[k + 1]);
@@ -58,15 +59,14 @@ kernel void int7pack_mm(
5859
uchar b5 = B_ptr[7 * (k / 8) + 5];
5960
uchar b6 = B_ptr[7 * (k / 8) + 6];
6061

61-
uchar w_val0 = b0 & 127;
62-
uchar w_val1 = b1 & 127;
63-
uchar w_val2 = b2 & 127;
64-
uchar w_val3 = b3 & 127;
65-
uchar w_val4 = b4 & 127;
66-
uchar w_val5 = b5 & 127;
67-
uchar w_val6 = b6 & 127;
68-
uchar w_val7 = ((b0 & 128) >> 7) | ((b1 & 128) >> 6) | ((b2 & 128) >> 5) | ((b3 & 128) >> 4)
69-
| ((b4 & 128) >> 3) | ((b5 & 128) >> 2) | ((b6 & 128) >> 1);
62+
uchar w_val0 = (b0 & 0x7f);
63+
uchar w_val1 = (b0 >> 7) | ((b1 & 0x3f) << 1);
64+
uchar w_val2 = (b1 >> 6) | ((b2 & 0x1f) << 2);
65+
uchar w_val3 = (b2 >> 5) | ((b3 & 0x0f) << 3);
66+
uchar w_val4 = (b3 >> 4) | ((b4 & 0x07) << 4);
67+
uchar w_val5 = (b4 >> 3) | ((b5 & 0x03) << 5);
68+
uchar w_val6 = (b5 >> 2) | ((b6 & 0x01) << 6);
69+
uchar w_val7 = (b6 >> 1);
7070

7171
rc += a_val0 * (scale * float(w_val0) + zero);
7272
rc += a_val1 * (scale * float(w_val1) + zero);

0 commit comments

Comments
 (0)