#include <metal_simdgroup>
#include <metal_stdlib>
using namespace metal;

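// Vec4Type maps a scalar type to its 4-wide vector type. It is referenced by
// the kernel below but not defined in this diff (presumably it lives earlier
// in the file); a minimal definition consistent with that use would be:
template <typename T> struct Vec4Type {};
template <> struct Vec4Type<float> {
  using type = float4;
};
template <> struct Vec4Type<half> {
  using type = half4;
};
#if __METAL_VERSION__ >= 310
template <> struct Vec4Type<bfloat> {
  using type = bfloat4;
};
#endif
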
/*
  This code takes heavy inspiration from MLX:
  https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/quantized.h
  Specifically:
  - Multiplying the activation by the inverse scaling factor to reduce compute
    boundedness
  - Handling the zero point by accumulating the activations in a separate sum
    term, which is needed once the optimization above is applied.
  MLX MIT License: https://github.com/ml-explore/mlx/blob/main/LICENSE
*/
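
/*
  Why the separate activation sum handles the zero point (a sketch of the
  algebra; not spelled out in the original): with affine dequantization
  w_k = s * q_k + z, the per-group dot product factors as

    sum_k a_k * (s * q_k + z) = s * sum_k a_k * q_k  +  z * sum_k a_k

  so the zero point only needs the running activation sum, computed below once
  per 4 values as a_val_sum.
*/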

/*
  @brief This shader implements 2-bit matrix-vector multiplication where the A
  matrix is fp16, bfloat or float and the B matrix is a 2-bit
  groupwise-quantized weight matrix.
  @param [in] A is the activation matrix of size M x K.
  @param [in] B is the weight matrix of size N x K. Each byte contains 4 2-bit
  values, packed together along the K dim.
  @param [in] scales_ptr contains the scales corresponding to each
  output channel x group, packed as [num_groups = ceil(K / group_size), N].
  N = output channels.
  @param [in] zeros_ptr contains the zero points corresponding to each
  output channel x group, packed as [num_groups = ceil(K / group_size), N].
  N = output channels.
  @param [out] output_data is the output matrix of size M x N.
  @param [in] sizes array contains the values of M, K and N.
  @param [in] thread_index is the global thread id.
  @param [in] tid_in_simdgroup is the thread id within the simdgroup, e.g. in a
  simdgroup of size 32 it lies in [0-31].
*/
template <typename T, unsigned group_size>
kernel void int2pack_mm(constant T *A [[buffer(0)]],
                        constant uchar *B [[buffer(1)]],
                        constant T *scales_ptr [[buffer(2)]],
                        constant T *zeros_ptr [[buffer(3)]],
                        device T *output_data [[buffer(4)]],
                        constant uint3 &sizes [[buffer(5)]], // M, K, N
                        uint3 thread_index [[thread_position_in_grid]],
                        uint tid_in_simdgroup [[thread_index_in_simdgroup]]) {
  constexpr uint threads_per_channel = 32;
  constexpr uint ks_per_thread = 4;
  constexpr uint k_pack_factor = 4;
  const uint K = sizes.y;
  const uint N = sizes.z;
  uint n = thread_index.x; // threads_per_channel threads cooperate on each group of 4 output channels
  uint m = thread_index.z; // 0..M-1
  n = n / threads_per_channel; // 0..N/4-1
  n = n * 4; // first of the 4 output channels handled by this thread
  // This is the starting k for each thread. E.g. with ks_per_thread = 4,
  // thread 1 of a channel group starts at k = 4. Each loop iteration then
  // strides by k_jump = threads_per_channel * ks_per_thread = 128.
  uint k = (tid_in_simdgroup % threads_per_channel) * ks_per_thread;
  constexpr int k_jump = threads_per_channel * ks_per_thread;

  using vecT = typename Vec4Type<T>::type;
  constant vecT *A_ptr = reinterpret_cast<constant vecT *>(A + m * K);
  constant uchar *B_ptr = B + ((n * K) / k_pack_factor);

  thread float4 result = float4(0.0);
  // We multiply the activations by these inverse scales rather than
  // right-shifting the masked weight values, which are effectively left
  // shifted by 0/2/4/6 bits. Skipping the right shifts on the weights avoids
  // a measurable performance hit; this is the trick applied in MLX kernels.
  float4 act_div_scales = {1.f, 1 / 4.f, 1 / 16.f, 1 / 64.f};
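  // E.g. for the second 2-bit lane: (b & 0x0c) == q1 << 2 == 4 * q1, so
  // (a1 / 4) * (4 * q1) == a1 * q1 without any shift.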

  for (; k < K; k += k_jump) {
    // Find the specific group to which the channels handled by this thread
    // belong.
    uint k_block_index = k / group_size;
    uint scales_group_offset = (k_block_index * N + n);

    vecT scales =
        (reinterpret_cast<constant vecT *>(scales_ptr + scales_group_offset))[0];
    // Adding the zero point results in a 10% perf penalty.
    vecT zeros =
        (reinterpret_cast<constant vecT *>(zeros_ptr + scales_group_offset))[0];
    float4 zeros_float = float4(zeros);

    float4 a_val = float4(A_ptr[k / 4]);
    // We skip the right-shifts of the weights and instead divide the
    // activations by the corresponding factors.
    float4 a_vec = a_val * act_div_scales;
    float a_val_sum = a_val[0] + a_val[1] + a_val[2] + a_val[3];

    float4x4 b_mat;
    // Each b_val byte holds 4 consecutive 2-bit weights (along K) for one of
    // the 4 output channels handled by this thread.
    ushort b_val0 = (B_ptr + (k + 0 * K) / k_pack_factor)[0];
    ushort b_val1 = (B_ptr + (k + 1 * K) / k_pack_factor)[0];
    ushort b_val2 = (B_ptr + (k + 2 * K) / k_pack_factor)[0];
    ushort b_val3 = (B_ptr + (k + 3 * K) / k_pack_factor)[0];
    b_mat[0] = scales[0] * float4(float(b_val0 & 0x03), float(b_val0 & 0x0c),
                                  float(b_val0 & 0x30), float(b_val0 & 0xc0));
    b_mat[1] = scales[1] * float4(float(b_val1 & 0x03), float(b_val1 & 0x0c),
                                  float(b_val1 & 0x30), float(b_val1 & 0xc0));
    b_mat[2] = scales[2] * float4(float(b_val2 & 0x03), float(b_val2 & 0x0c),
                                  float(b_val2 & 0x30), float(b_val2 & 0xc0));
    b_mat[3] = scales[3] * float4(float(b_val3 & 0x03), float(b_val3 & 0x0c),
                                  float(b_val3 & 0x30), float(b_val3 & 0xc0));

    // float4 * float4x4 treats a_vec as a row vector: result[c] accumulates
    // dot(a_vec, b_mat[c]), one partial product per output channel.
    result += a_vec * b_mat;
    result += a_val_sum * zeros_float;
  }
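  // Tree-reduce the per-thread partial sums across the 32 lanes of the
  // simdgroup; the first lane of each channel group ends with the full sums.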
  result += simd_shuffle_down(result, 1);
  result += simd_shuffle_down(result, 2);
  result += simd_shuffle_down(result, 4);
  result += simd_shuffle_down(result, 8);
  result += simd_shuffle_down(result, 16);
  if (tid_in_simdgroup % threads_per_channel == 0) {
    reinterpret_cast<device vecT *>(output_data + m * N)[n / 4] = vecT(result);
  }
}

#define INSTANTIATE_INT2MM(DTYPE, GSIZE)                                      \
  template [[host_name("int2pack_mm_" #GSIZE "_" #DTYPE)]] kernel void        \
  int2pack_mm<DTYPE, GSIZE>(                                                  \
      constant DTYPE * A [[buffer(0)]], constant uchar * B [[buffer(1)]],    \
      constant DTYPE * scales_ptr [[buffer(2)]],                              \
      constant DTYPE * zeros_ptr [[buffer(3)]],                               \
      device DTYPE * output_data [[buffer(4)]],                               \
      constant uint3 & sizes [[buffer(5)]],                                   \
      uint3 thread_index [[thread_position_in_grid]],                         \
      uint tid_in_simdgroup [[thread_index_in_simdgroup]])

INSTANTIATE_INT2MM(float, 32);
INSTANTIATE_INT2MM(half, 32);
INSTANTIATE_INT2MM(float, 64);
INSTANTIATE_INT2MM(half, 64);
INSTANTIATE_INT2MM(float, 128);
INSTANTIATE_INT2MM(half, 128);
INSTANTIATE_INT2MM(float, 256);
INSTANTIATE_INT2MM(half, 256);
#if __METAL_VERSION__ >= 310
INSTANTIATE_INT2MM(bfloat, 32);
INSTANTIATE_INT2MM(bfloat, 64);
INSTANTIATE_INT2MM(bfloat, 128);
INSTANTIATE_INT2MM(bfloat, 256);
#endif
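
/*
  Host-side dispatch (a sketch, not part of this file): variants are looked up
  by the host_name strings above, e.g. "int2pack_mm_32_float" for T = float,
  group_size = 32. From the indexing in the kernel, threads_per_channel (32)
  threads along x cooperate on each group of 4 output channels, so a plausible
  launch, assuming one simdgroup per channel group, is:

    grid            = MTLSizeMake(N / 4 * 32, 1, M);
    threadsPerGroup = MTLSizeMake(32, 1, 1);

  with the `sizes` buffer set to {M, K, N}.
*/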