Commit 2e032c6

metal lowbit kernels: optimized 2-bit, 3-bit and 4-bit shaders (#1422)
1 parent f52d3ab commit 2e032c6

17 files changed: +599 -188 lines

torchao/experimental/kernels/mps/metal.yaml

Lines changed: 7 additions & 4 deletions
@@ -1,14 +1,17 @@
+- func: Vec4Type
+  file: common.metal
+
 - func: int1mm
-  file: divbit.metal
+  file: int1mm.metal
 
 - func: int2mm
-  file: divbit.metal
+  file: int2mm_opt.metal
 
 - func: int3mm
-  file: int3mm.metal
+  file: int3mm_opt.metal
 
 - func: int4mm
-  file: divbit.metal
+  file: int4mm_opt.metal
 
 - func: int5mm
   file: int5mm.metal
torchao/experimental/kernels/mps/metal/common.metal

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+template <typename T> struct Vec4Type {};
+
+template <> struct Vec4Type<float> {
+  using type = float4;
+};
+
+template <> struct Vec4Type<half> {
+  using type = half4;
+};
+
+#if __METAL_VERSION__ >= 310
+template <> struct Vec4Type<bfloat> {
+  using type = bfloat4;
+};
+#endif
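
This trait maps a scalar element type to its 4-wide Metal vector type (float -> float4, half -> half4, and bfloat -> bfloat4 on Metal 3.1+). The optimized shaders below use it to reinterpret scalar buffers as vec4 pointers so that activations, scales, and zero points are loaded four at a time. A minimal host-side C++ sketch of the same reinterpret-and-load pattern (the Vec4 struct and buffer names are illustrative stand-ins, not part of this commit):

// Plain C++ stand-in for the shader's vectorized loads: view a row of
// scalars as 4-wide vectors so one load covers four values.
#include <cstdio>

struct Vec4 { float x, y, z, w; };  // stand-in for Metal's float4

int main() {
  float A[8] = {0, 1, 2, 3, 4, 5, 6, 7};  // one K=8 activation row
  // Mirrors the shader's reinterpret_cast; fine for this illustration.
  const Vec4 *A_vec = reinterpret_cast<const Vec4 *>(A);
  for (int k = 0; k < 8; k += 4) {  // one vector load per 4 scalars
    Vec4 v = A_vec[k / 4];          // mirrors A_ptr[k / 4] in the shaders
    std::printf("%g %g %g %g\n", v.x, v.y, v.z, v.w);
  }
  return 0;
}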

torchao/experimental/kernels/mps/metal/divbit.metal

Lines changed: 0 additions & 109 deletions
This file was deleted.

torchao/experimental/kernels/mps/metal/int3mm.metal renamed to torchao/experimental/kernels/mps/metal/int1mm.metal

Lines changed: 29 additions & 32 deletions
@@ -2,10 +2,10 @@
 using namespace metal;
 
 /**
- * 3-Bit Quantized Linear.
+ * 1-Bit Quantized Linear.
  *
- * @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16)
- * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (3 * K / 8)
+ * @param[A] M x K input tensor of floating point dtype (Float, Half, BFloat16)
+ * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (K / 8)
  * @param[scales] 2D tensor containing the scales for each group. Expected shape is #groups x N
  * @param[zeros] 2D tensor containing the zero points for each group. Expected shape is #groups x N
  * @param[outputData] M x N output tensor of floating point dtype (same as input)
@@ -14,7 +14,7 @@ using namespace metal;
  * Dispatched threads: N x M x 1
  */
 template<typename T, unsigned groupSize>
-kernel void int3pack_mm(
+kernel void int1pack_mm(
     constant T * A [[buffer(0)]],
     constant uchar * B [[buffer(1)]],
     constant T * scales [[buffer(2)]],
@@ -28,7 +28,7 @@ kernel void int3pack_mm(
   const uint n = thread_index.x; // 0..N-1
   const uint32_t k_block = (K + groupSize - 1) / groupSize;
   constant T *A_ptr = A + m * K;
-  constant uchar *B_ptr = B + n * 3 * K / 8;
+  constant uchar *B_ptr = B + n * K / 8;
 
   float rc = 0.0;
   uint k = 0;
@@ -45,19 +45,16 @@ kernel void int3pack_mm(
     const auto a_val6 = float(A_ptr[k + 6]);
     const auto a_val7 = float(A_ptr[k + 7]);
 
-    uchar b0 = B_ptr[3 * (k / 8) + 0];
-    uchar b1 = B_ptr[3 * (k / 8) + 1];
-    uchar b2 = B_ptr[3 * (k / 8) + 2];
+    uchar b0 = B_ptr[(k / 8)];
 
-    uchar w_val0 = ((b0 & 1) << 2) | (b1 & 3);
-    uchar w_val1 = ((b0 & 2) << 1) | ((b1 & 12) >> 2);
-    uchar w_val2 = (b0 & 4) | ((b1 & 48) >> 4);
-    uchar w_val3 = ((b0 & 8) >> 1) | ((b1 & 192) >> 6);
-
-    uchar w_val4 = ((b0 & 16) >> 2) | (b2 & 3);
-    uchar w_val5 = ((b0 & 32) >> 3) | ((b2 & 12) >> 2);
-    uchar w_val6 = ((b0 & 64) >> 4) | ((b2 & 48) >> 4);
-    uchar w_val7 = ((b0 & 128) >> 5) | ((b2 & 192) >> 6);
+    uchar w_val0 = b0 & 0x01;
+    uchar w_val1 = (b0 & 0x02) >> 1;
+    uchar w_val2 = (b0 & 0x04) >> 2;
+    uchar w_val3 = (b0 & 0x08) >> 3;
+    uchar w_val4 = (b0 & 0x10) >> 4;
+    uchar w_val5 = (b0 & 0x20) >> 5;
+    uchar w_val6 = (b0 & 0x40) >> 6;
+    uchar w_val7 = (b0 & 0x80) >> 7;
 
     rc += a_val0 * (scale * float(w_val0) + zero);
     rc += a_val1 * (scale * float(w_val1) + zero);
@@ -72,10 +69,10 @@ kernel void int3pack_mm(
   outputData[m * N + n] = T(rc);
 }
 
-#define INSTANTIATE_INT3MM(DTYPE, GSIZE) \
+#define INSTANTIATE_INT1MM(DTYPE, GSIZE) \
   template \
-      [[host_name("int3pack_mm_" #GSIZE "_" #DTYPE)]] \
-  kernel void int3pack_mm<DTYPE, GSIZE>( \
+      [[host_name("int1pack_mm_" #GSIZE "_" #DTYPE)]] \
+  kernel void int1pack_mm<DTYPE, GSIZE>( \
       constant DTYPE * A [[buffer(0)]], \
       constant uchar * B [[buffer(1)]], \
      constant DTYPE * scales [[buffer(2)]], \
@@ -84,17 +81,17 @@ kernel void int3pack_mm<DTYPE, GSIZE>( \
       constant uint3 & sizes [[buffer(5)]], \
       uint2 thread_index [[thread_position_in_grid]])
 
-INSTANTIATE_INT3MM(float, 32);
-INSTANTIATE_INT3MM(half, 32);
-INSTANTIATE_INT3MM(float, 64);
-INSTANTIATE_INT3MM(half, 64);
-INSTANTIATE_INT3MM(float, 128);
-INSTANTIATE_INT3MM(half, 128);
-INSTANTIATE_INT3MM(float, 256);
-INSTANTIATE_INT3MM(half, 256);
+INSTANTIATE_INT1MM(float, 32);
+INSTANTIATE_INT1MM(half, 32);
+INSTANTIATE_INT1MM(float, 64);
+INSTANTIATE_INT1MM(half, 64);
+INSTANTIATE_INT1MM(float, 128);
+INSTANTIATE_INT1MM(half, 128);
+INSTANTIATE_INT1MM(float, 256);
+INSTANTIATE_INT1MM(half, 256);
 #if __METAL_VERSION__ >= 310
-INSTANTIATE_INT3MM(bfloat, 32);
-INSTANTIATE_INT3MM(bfloat, 64);
-INSTANTIATE_INT3MM(bfloat, 128);
-INSTANTIATE_INT3MM(bfloat, 256);
+INSTANTIATE_INT1MM(bfloat, 32);
+INSTANTIATE_INT1MM(bfloat, 64);
+INSTANTIATE_INT1MM(bfloat, 128);
+INSTANTIATE_INT1MM(bfloat, 256);
 #endif
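
As a cross-check of the packing scheme, here is a host-side C++ reference sketch of the 1-bit dequantize-and-accumulate loop above, simplified to a single output channel with per-group scalar scales and zero points (the function name and layout simplifications are illustrative; the shader runs this across an N x M grid of threads):

// C++ reference for the 1-bit path: each byte of the weight row packs 8
// one-bit weights along K, with bit i holding weight k+i, exactly as the
// shader's w_val0..w_val7 unpack does.
#include <cstdint>
#include <cstdio>

float dot_1bit(const float *a, const uint8_t *b_row, const float *scales,
               const float *zeros, unsigned K, unsigned groupSize) {
  float rc = 0.0f;
  for (unsigned k = 0; k < K; ++k) {
    uint8_t byte = b_row[k / 8];
    uint8_t w = (byte >> (k % 8)) & 0x01;  // same unpack as the shader
    unsigned g = k / groupSize;            // group index for scale/zero
    rc += a[k] * (scales[g] * float(w) + zeros[g]);  // dequantize, accumulate
  }
  return rc;
}

int main() {
  float a[8] = {1, 1, 1, 1, 1, 1, 1, 1};
  uint8_t b = 0b10110001;  // weights 1,0,0,0,1,1,0,1 for k = 0..7
  float s = 2.0f, z = -1.0f;
  // Four ones and four zeros: 4*(2*1 - 1) + 4*(2*0 - 1) = 0
  std::printf("%g\n", dot_1bit(a, &b, &s, &z, 8, 32));
  return 0;
}
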
torchao/experimental/kernels/mps/metal/int2mm_opt.metal

Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
+#include <metal_simdgroup>
+#include <metal_stdlib>
+using namespace metal;
+
+/*
+   This code takes heavy inspiration from MLX:
+   https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/quantized.h
+   Specifically:
+   - Multiplying the activation by the inverse scaling factor to reduce
+   compute boundedness
+   - Handling the zero point by accumulating the activation in a separate sum
+   term, which is needed with the optimization above. MLX MIT License:
+   https://github.com/ml-explore/mlx/blob/main/LICENSE
+*/
+
+/*
+   @brief This shader implements 2-bit matrix-vector multiplication where the A
+   matrix is fp16, bfloat or float and the B matrix is a 2-bit
+   groupwise-quantized weight matrix.
+   @param [in] A is the activation matrix of size M x K.
+   @param [in] B is the weight matrix of size N x K. Each byte contains 4 2-bit
+   values, packed together along the K dim.
+   @param [in] scales_ptr points to the scales for each
+   output channel x group. These are packed as [num_groups = ceil(K / group_size), N].
+   N = output channels.
+   @param [in] zeros_ptr points to the zero points for each
+   output channel x group. These are packed as [num_groups = ceil(K / group_size), N].
+   N = output channels.
+   @param [out] output_data is the output matrix of size M x N.
+   @param [in] sizes array contains the values of M, K and N.
+   @param [in] thread_index is the global thread id.
+   @param [in] tid_in_simdgroup is the thread id in the simdgroup, e.g. in a simdgroup of size 32 it is in [0-31].
+*/
+template <typename T, unsigned group_size>
+kernel void int2pack_mm(constant T *A [[buffer(0)]],
+                        constant uchar *B [[buffer(1)]],
+                        constant T *scales_ptr [[buffer(2)]],
+                        constant T *zeros_ptr [[buffer(3)]],
+                        device T *output_data [[buffer(4)]],
+                        constant uint3 &sizes [[buffer(5)]], // M, K, N
+                        uint3 thread_index [[thread_position_in_grid]],
+                        uint tid_in_simdgroup [[thread_index_in_simdgroup]]) {
+  constexpr uint threads_per_channel = 32;
+  constexpr uint ks_per_thread = 4;
+  constexpr uint k_pack_factor = 4;
+  const uint K = sizes.y;
+  const uint N = sizes.z;
+  uint n = thread_index.x; // 0..N/4-1
+  uint m = thread_index.z; // 0..M
+  n = n / threads_per_channel;
+  n = n * 4;
+  // Starting k for each thread; e.g., thread 1 in a simdgroup starts at k = 4.
+  uint k = (tid_in_simdgroup % threads_per_channel) * ks_per_thread;
+  constexpr int k_jump = threads_per_channel * ks_per_thread;
+
+  using vecT = typename Vec4Type<T>::type;
+  constant vecT *A_ptr = reinterpret_cast<constant vecT *>(A + m * K);
+  constant uchar *B_ptr = B + ((n * K) / k_pack_factor);
+
+  thread float4 result = float4(0.0);
+  // Activation divisors for the 4 values packed in each byte: the
+  // corresponding weight values are effectively left shifted, and dividing
+  // the activation avoids right-shifting them, which ends up affecting
+  // performance. This is the trick applied in MLX kernels.
+  float4 act_div_scales = {1.f, 1 / 4.f, 1 / 16.f, 1 / 64.f};
+
+  for (; k < K; k += k_jump) {
+    // Find the specific group to which the channels handled by this thread
+    // belong.
+    uint k_block_index = k / group_size;
+    uint scales_group_offset = (k_block_index * N + n);
+
+    vecT scales =
+        (reinterpret_cast<constant vecT *>(scales_ptr + scales_group_offset))[0];
+    // Adding the zero point results in a 10% perf penalty.
+    vecT zeros =
+        (reinterpret_cast<constant vecT *>(zeros_ptr + scales_group_offset))[0];
+    float4 zeros_float = float4(zeros);
+
+    float4 a_val = float4(A_ptr[k / 4]);
+    // We skip right-shifts of the weights and hence divide by the
+    // corresponding factor.
+    float4 a_vec = a_val * act_div_scales;
+    float a_val_sum = a_val[0] + a_val[1] + a_val[2] + a_val[3];
+
+    float4x4 b_mat;
+    ushort b_val0 = (B_ptr + (k + 0 * K) / k_pack_factor)[0];
+    ushort b_val1 = (B_ptr + (k + 1 * K) / k_pack_factor)[0];
+    ushort b_val2 = (B_ptr + (k + 2 * K) / k_pack_factor)[0];
+    ushort b_val3 = (B_ptr + (k + 3 * K) / k_pack_factor)[0];
+    b_mat[0] = scales[0] * float4(float(b_val0 & 0x03), float(b_val0 & 0x0c),
+                                  float(b_val0 & 0x30), float(b_val0 & 0xc0));
+    b_mat[1] = scales[1] * float4(float(b_val1 & 0x03), float(b_val1 & 0x0c),
+                                  float(b_val1 & 0x30), float(b_val1 & 0xc0));
+    b_mat[2] = scales[2] * float4(float(b_val2 & 0x03), float(b_val2 & 0x0c),
+                                  float(b_val2 & 0x30), float(b_val2 & 0xc0));
+    b_mat[3] = scales[3] * float4(float(b_val3 & 0x03), float(b_val3 & 0x0c),
+                                  float(b_val3 & 0x30), float(b_val3 & 0xc0));
+
+    result += a_vec * b_mat;
+    result += a_val_sum * zeros_float;
+  }
+  result += simd_shuffle_down(result, 1);
+  result += simd_shuffle_down(result, 2);
+  result += simd_shuffle_down(result, 4);
+  result += simd_shuffle_down(result, 8);
+  result += simd_shuffle_down(result, 16);
+  if (tid_in_simdgroup % threads_per_channel == 0) {
+    reinterpret_cast<device vecT *>(output_data + m * N)[n / 4] = vecT(result);
+  }
+}
+
+#define INSTANTIATE_INT2MM(DTYPE, GSIZE) \
+  template [[host_name("int2pack_mm_" #GSIZE "_" #DTYPE)]] kernel void \
+  int2pack_mm<DTYPE, GSIZE>( \
+      constant DTYPE * A [[buffer(0)]], constant uchar * B [[buffer(1)]], \
+      constant DTYPE * scales_ptr [[buffer(2)]], \
+      constant DTYPE * zeros_ptr [[buffer(3)]], \
+      device DTYPE * output_data [[buffer(4)]], \
+      constant uint3 & sizes [[buffer(5)]], \
+      uint3 thread_index [[thread_position_in_grid]], \
+      uint tid_in_simdgroup [[thread_index_in_simdgroup]])
+
+INSTANTIATE_INT2MM(float, 32);
+INSTANTIATE_INT2MM(half, 32);
+INSTANTIATE_INT2MM(float, 64);
+INSTANTIATE_INT2MM(half, 64);
+INSTANTIATE_INT2MM(float, 128);
+INSTANTIATE_INT2MM(half, 128);
+INSTANTIATE_INT2MM(float, 256);
+INSTANTIATE_INT2MM(half, 256);
+#if __METAL_VERSION__ >= 310
+INSTANTIATE_INT2MM(bfloat, 32);
+INSTANTIATE_INT2MM(bfloat, 64);
+INSTANTIATE_INT2MM(bfloat, 128);
+INSTANTIATE_INT2MM(bfloat, 256);
+#endif
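
A plain C++ sketch (with made-up values, not shader code) verifying the two MLX-inspired identities this kernel relies on: masking a packed byte without right-shifting yields each 2-bit weight scaled by a power of 4, which pre-multiplying the activation by the matching inverse factor cancels; and the zero-point contribution factors out as the zero times the activation sum:

#include <cstdio>

int main() {
  // Trick 1: skip the right shift. For a byte b packing four 2-bit weights,
  // (b & 0x0c) equals 4 * w1, so multiplying the activation by 1/4 up front
  // gives a * w1 without shifting the weight.
  unsigned char b = 0b10011110;  // packs w3..w0 = 2, 1, 3, 2
  float a = 5.0f;
  float lhs = (a * (1.0f / 4.0f)) * float(b & 0x0c);  // pre-divided activation
  float rhs = a * float((b >> 2) & 0x03);             // conventional unpack of w1
  std::printf("shift-free: %g == %g\n", lhs, rhs);    // both print 15 (w1 = 3)

  // Trick 2: fold the zero point into a separate activation sum:
  //   sum_k a_k * (s * w_k + z) = s * sum_k (a_k * w_k) + z * sum_k a_k
  float as[4] = {1, 2, 3, 4}, ws[4] = {0, 1, 2, 3}, s = 0.5f, z = -1.0f;
  float direct = 0, dot = 0, a_sum = 0;
  for (int k = 0; k < 4; ++k) {
    direct += as[k] * (s * ws[k] + z);  // dequantize per element
    dot += as[k] * ws[k];               // plain quantized dot product
    a_sum += as[k];                     // activation sum
  }
  std::printf("zero-point: %g == %g\n", direct, s * dot + z * a_sum);  // both 0
  return 0;
}

The second identity is why `result += a_val_sum * zeros_float` can stand in for adding the zero point to every dequantized weight inside the inner loop.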
