Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions torchao/experimental/kernels/mps/metal.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,6 @@

- func: int7mm
file: int7mm.metal

- func: qmv_fast
file: qmv_fast.metal
9 changes: 5 additions & 4 deletions torchao/experimental/kernels/mps/metal/int1mm.metal
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ using namespace metal;
*
* @param[A] M x K input tensor of floating point dtype (Float, Half, BFloat16)
* @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (K / 8)
* @param[scales] 2D tensor containg the scales for each group. Expected shape is #groups x N
* @param[zeros] 2D tensor containg the zero points for each group. Expected shape is #groups x N
 * @param[scales] 2D tensor containing the scales for each group. Expected shape is N x #groups
 * @param[zeros] 2D tensor containing the zero points for each group. Expected shape is N x #groups
* @param[outputData] M x N output tensor of floating point dtype (same as input)
* @param[sizes] The sizes involved in the order: M, K, N
*
Expand All @@ -29,6 +29,7 @@ kernel void int1pack_mm(
uint2 thread_index [[thread_position_in_grid]]) {
const uint K = sizes.y;
const uint N = sizes.z;
const uint num_groups = (K + groupSize - 1) / groupSize;
const uint m = thread_index.y; // 0..M-1
const uint n = thread_index.x; // 0..N-1
const uint32_t k_block = (K + groupSize - 1) / groupSize;
Expand All @@ -38,8 +39,8 @@ kernel void int1pack_mm(
float rc = 0.0;
uint k = 0;
for (uint32_t kb = 0; kb < k_block ; kb ++) {
const float scale = float(scales[kb * N + n]);
const float zero = float(zeros[kb * N + n]);
const float scale = float(scales[n * num_groups + kb]);
const float zero = float(zeros[n * num_groups + kb]);
for(uint idx = 0; idx < groupSize && k < K; idx+=8, k+=8) {
const auto a_val0 = float(A_ptr[k + 0]);
const auto a_val1 = float(A_ptr[k + 1]);
Expand Down
19 changes: 12 additions & 7 deletions torchao/experimental/kernels/mps/metal/int2mm_opt.metal
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,11 @@ using namespace metal;
@param [in] B is weight matrix of size M x K. Each byte contains 4 2-bit
values, along K dim, packed together.
@param [in] scales_ptr is the scales pointer corresponding to each
output channel x groups. These are packed as [num_groups = ceil(K / group_size), N]. N = output
output channel x groups. These are packed as [N, num_groups = ceil(K / group_size)]. N = output
channels.
@param [in] zeros_ptr is the zero points pointer corresponding to each
output channel x groups. These are packed as [num_groups = ceil(K / group_size), N]. N = output
output channel x groups. These are packed as [N, num_groups = ceil(K / group_size)]. N = output
channels.
output channel x groups. These are packed as [num_groups = ceil(K / group_size), N, 2]. N = output
@param [out] output_data is output matrix of size M x N.
@param [in] sizes array contains values of M, K and N.
@param [in] thread_index is global thread id.
Expand All @@ -51,6 +50,7 @@ kernel void int2pack_mm(constant T *A [[buffer(0)]],
constexpr uint k_pack_factor = 4;
const uint K = sizes.y;
const uint N = sizes.z;
const uint num_groups = (K + group_size - 1) / group_size;
uint n = thread_index.x; // 0..N/4-1
uint m = thread_index.z; // 0..M
n = n / threads_per_channel;
Expand All @@ -75,13 +75,18 @@ kernel void int2pack_mm(constant T *A [[buffer(0)]],
// Find specific group to which channels handled by this thread
// belong.
uint k_block_index = k / group_size;
uint scales_group_offset = (k_block_index * N + n);
uint scales_group_offset = (n * num_groups + k_block_index);

vecT scales =
(reinterpret_cast<constant vecT *>(scales_ptr + scales_group_offset))[0];
// Adding zero point results in 10% perf penalty.
vecT(scales_ptr[scales_group_offset],
scales_ptr[scales_group_offset + num_groups],
scales_ptr[scales_group_offset + 2 * num_groups],
scales_ptr[scales_group_offset + 3 * num_groups]);
vecT zeros =
(reinterpret_cast<constant vecT *>(zeros_ptr + scales_group_offset))[0];
vecT(zeros_ptr[scales_group_offset],
zeros_ptr[scales_group_offset + num_groups],
zeros_ptr[scales_group_offset + 2 * num_groups],
zeros_ptr[scales_group_offset + 3 * num_groups]);
float4 zeros_float = float4(zeros);

float4 a_val = float4(A_ptr[k / 4]);
Expand Down
34 changes: 20 additions & 14 deletions torchao/experimental/kernels/mps/metal/int3mm_opt.metal
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,23 @@
using namespace metal;

inline void unpack_3bit(const uchar3 b, thread float* w) {
w[0] = float(((b[0] & 1) << 2) | (b[1] & 3));
w[1] = float(((b[0] & 2) << 1) | ((b[1] & 12) >> 2));
w[2] = float((b[0] & 4) | ((b[1] & 48) >> 4));
w[3] = float(((b[0] & 8) >> 1) | ((b[1] & 192) >> 6));

w[4] = float(((b[0] & 16) >> 2) | (b[2] & 3));
w[5] = float(((b[0] & 32) >> 3) | ((b[2] & 12) >> 2));
w[6] = float(((b[0] & 64) >> 4) | ((b[2] & 48) >> 4));
w[7] = float(((b[0] & 128) >> 5) | ((b[2] & 192) >> 6));
w[0] = float(b[0] & 0x07);
w[1] = float((b[0] & 0x38) >> 3);
w[2] = float(((b[0] & 0xc0) >> 6) | ((b[1] & 0x01) << 2));
w[3] = float((b[1] & 0x0e) >> 1);
w[4] = float((b[1] & 0x70) >> 4);
w[5] = float(((b[1] & 0x80) >> 7) | ((b[2] & 0x03) << 1));
w[6] = float((b[2] & 0x1c) >> 2);
w[7] = float((b[2] & 0xe0) >> 5);
Comment on lines +11 to +18
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am definitely surprised that this is better

}

/**
* 3-Bit Quantized Linear.
*
* @param[A] M x K input tensor of floating point dtype (Float, Half, BFloat16)
* @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (3 * K / 8)
* @param[scales] 2D tensor containg the scales for each group. Expected shape is #groups x N
* @param[zeros] 2D tensor containg the zero points for each group. Expected shape is #groups x N
 * @param[scales] 2D tensor containing the scales for each group. Expected shape is N x #groups
 * @param[zeros] 2D tensor containing the zero points for each group. Expected shape is N x #groups
* @param[outputData] M x N output tensor of floating point dtype (same as input)
* @param[sizes] The sizes involved in the order: M, K, N
*
Expand All @@ -45,6 +44,7 @@ kernel void int3pack_mm(constant T *A [[buffer(0)]],
constexpr uint k_pack_factor = 8;
const uint K = sizes.y;
const uint N = sizes.z;
const uint num_groups = (K + group_size - 1) / group_size;
uint n = thread_index.x; // 0..N/4-1
uint m = thread_index.z; // 0..M
n = n / threads_per_channel;
Expand All @@ -64,12 +64,18 @@ kernel void int3pack_mm(constant T *A [[buffer(0)]],
// Find specific group to which channels handled by this thread
// belong.
uint k_block_index = k / group_size;
uint scales_group_offset = (k_block_index * N + n);
uint scales_group_offset = (n * num_groups + k_block_index);

vecT scales =
(reinterpret_cast<constant vecT *>(scales_ptr + scales_group_offset))[0];
vecT(scales_ptr[scales_group_offset],
scales_ptr[scales_group_offset + num_groups],
scales_ptr[scales_group_offset + 2 * num_groups],
scales_ptr[scales_group_offset + 3 * num_groups]);
vecT zeros =
(reinterpret_cast<constant vecT *>(zeros_ptr + scales_group_offset))[0];
vecT(zeros_ptr[scales_group_offset],
zeros_ptr[scales_group_offset + num_groups],
zeros_ptr[scales_group_offset + 2 * num_groups],
zeros_ptr[scales_group_offset + 3 * num_groups]);
float4 zeros_float = float4(zeros);

float4 a_val[2];
Expand Down
18 changes: 12 additions & 6 deletions torchao/experimental/kernels/mps/metal/int4mm_opt.metal
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,11 @@ using namespace metal;
@param [in] B is weight matrix of size M x K. Each byte contains 2 4-bit
values, along K dim, packed together.
@param [in] scales_ptr is the scales pointer corresponding to each
output channel x groups. These are packed as [num_groups = ceil(K / group_size), N]. N = output
output channel x groups. These are packed as [N, num_groups = ceil(K / group_size)]. N = output
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would this work for gemm as well?

channels.
@param [in] zeros_ptr is the zero points pointer corresponding to each
output channel x groups. These are packed as [num_groups = ceil(K / group_size), N]. N = output
output channel x groups. These are packed as [N, num_groups = ceil(K / group_size)]. N = output
channels.
output channel x groups. These are packed as [num_groups = ceil(K / group_size), N, 2]. N = output
@param [out] output_data is output matrix of size M x N.
@param [in] sizes array contains values of M, K and N.
@param [in] thread_index is global thread id.
Expand All @@ -89,6 +88,7 @@ kernel void int4pack_mm(constant T *A [[buffer(0)]],
constexpr uint k_pack_factor = 2;
const uint K = sizes.y;
const uint N = sizes.z;
const uint num_groups = (K + group_size - 1) / group_size;
uint n = thread_index.x; // 0..N/4-1
uint m = thread_index.z; // 0..M
n = n / threads_per_channel;
Expand All @@ -113,13 +113,19 @@ kernel void int4pack_mm(constant T *A [[buffer(0)]],
// Find specific group to which channels handled by this thread
// belong.
uint k_block_index = k / group_size;
uint scales_group_offset = (k_block_index * N + n);
uint scales_group_offset = (n * num_groups + k_block_index);

vecT scales =
(reinterpret_cast<constant vecT *>(scales_ptr + scales_group_offset))[0];
vecT(scales_ptr[scales_group_offset],
scales_ptr[scales_group_offset + num_groups],
scales_ptr[scales_group_offset + 2 * num_groups],
scales_ptr[scales_group_offset + 3 * num_groups]);
// Adding zero point results in 10% perf penalty.
vecT zeros =
(reinterpret_cast<constant vecT *>(zeros_ptr + scales_group_offset))[0];
vecT(zeros_ptr[scales_group_offset],
zeros_ptr[scales_group_offset + num_groups],
zeros_ptr[scales_group_offset + 2 * num_groups],
zeros_ptr[scales_group_offset + 3 * num_groups]);
float4 zeros_float = float4(zeros);

float4 a_val = float4(A_ptr[k / 4]);
Expand Down
26 changes: 13 additions & 13 deletions torchao/experimental/kernels/mps/metal/int5mm.metal
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ using namespace metal;
*
* @param[A] M x K input tensor of floating point dtype (Float, Half, BFloat16)
* @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (5 * K / 8)
* @param[scales] 2D tensor containg the scales for each group. Expected shape is #groups x N
* @param[zeros] 2D tensor containg the zero points for each group. Expected shape is #groups x N
 * @param[scales] 2D tensor containing the scales for each group. Expected shape is N x #groups
 * @param[zeros] 2D tensor containing the zero points for each group. Expected shape is N x #groups
* @param[outputData] M x N output tensor of floating point dtype (same as input)
* @param[sizes] The sizes involved in the order: M, K, N
*
Expand All @@ -29,6 +29,7 @@ kernel void int5pack_mm(
uint2 thread_index [[thread_position_in_grid]]) {
const uint K = sizes.y;
const uint N = sizes.z;
const uint num_groups = (K + groupSize - 1) / groupSize;
const uint m = thread_index.y; // 0..M-1
const uint n = thread_index.x; // 0..N-1
const uint32_t k_block = (K + groupSize - 1) / groupSize;
Expand All @@ -38,8 +39,8 @@ kernel void int5pack_mm(
float rc = 0.0;
uint k = 0;
for (uint32_t kb = 0; kb < k_block ; kb ++) {
const float scale = float(scales[kb * N + n]);
const float zero = float(zeros[kb * N + n]);
const float scale = float(scales[n * num_groups + kb]);
const float zero = float(zeros[n * num_groups + kb]);
for(uint idx = 0; idx < groupSize && k < K; idx+=8, k+=8) {
const auto a_val0 = float(A_ptr[k + 0]);
const auto a_val1 = float(A_ptr[k + 1]);
Expand All @@ -56,15 +57,14 @@ kernel void int5pack_mm(
uchar b3 = B_ptr[5 * (k / 8) + 3];
uchar b4 = B_ptr[5 * (k / 8) + 4];

uchar w_val0 = ((b0 & 1) << 4) | (b1 & 15);
uchar w_val1 = ((b0 & 2) << 3) | ((b1 & 240) >> 4);
uchar w_val2 = ((b0 & 4) << 2) | (b2 & 15);
uchar w_val3 = ((b0 & 8) << 1) | ((b2 & 240) >> 4);

uchar w_val4 = ((b0 & 16)) | (b3 & 15);
uchar w_val5 = ((b0 & 32) >> 1) | ((b3 & 240) >> 4);
uchar w_val6 = ((b0 & 64) >> 2) | (b4 & 15);
uchar w_val7 = ((b0 & 128) >> 3) | ((b4 & 240) >> 4);
uchar w_val0 = (b0 & 0x1f);
uchar w_val1 = ((b0 & 0xe0) >> 5) | ((b1 & 0x03) << 3);
uchar w_val2 = ((b1 & 0x7c) >> 2);
uchar w_val3 = ((b1 & 0x80) >> 7) | ((b2 & 0x0f) << 1);
uchar w_val4 = ((b2 & 0xf0) >> 4) | ((b3 & 0x01) << 4);
uchar w_val5 = ((b3 & 0x3e) >> 1);
uchar w_val6 = ((b3 & 0xc0) >> 6) | ((b4 & 0x07) << 2);
uchar w_val7 = ((b4 & 0xf8) >> 3);

rc += a_val0 * (scale * float(w_val0) + zero);
rc += a_val1 * (scale * float(w_val1) + zero);
Expand Down
25 changes: 13 additions & 12 deletions torchao/experimental/kernels/mps/metal/int6mm.metal
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ using namespace metal;
*
* @param[A] M x K input tensor of floating point dtype (Float, Half, BFloat16)
* @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (6 * K / 8)
* @param[scales] 2D tensor containg the scales for each group. Expected shape is #groups x N
* @param[zeros] 2D tensor containg the zero points for each group. Expected shape is #groups x N
 * @param[scales] 2D tensor containing the scales for each group. Expected shape is N x #groups
 * @param[zeros] 2D tensor containing the zero points for each group. Expected shape is N x #groups
* @param[outputData] M x N output tensor of floating point dtype (same as input)
* @param[sizes] The sizes involved in the order: M, K, N
*
Expand All @@ -29,6 +29,7 @@ kernel void int6pack_mm(
uint2 thread_index [[thread_position_in_grid]]) {
const uint K = sizes.y;
const uint N = sizes.z;
const uint num_groups = (K + groupSize - 1) / groupSize;
const uint m = thread_index.y; // 0..M-1
const uint n = thread_index.x; // 0..N-1
const uint32_t k_block = (K + groupSize - 1) / groupSize;
Expand All @@ -38,8 +39,8 @@ kernel void int6pack_mm(
float rc = 0.0;
uint k = 0;
for (uint32_t kb = 0; kb < k_block ; kb ++) {
const float scale = float(scales[kb * N + n]);
const float zero = float(zeros[kb * N + n]);
const float scale = float(scales[n * num_groups + kb]);
const float zero = float(zeros[n * num_groups + kb]);
for(uint idx = 0; idx < groupSize && k < K; idx+=8, k+=8) {
const auto a_val0 = float(A_ptr[k + 0]);
const auto a_val1 = float(A_ptr[k + 1]);
Expand All @@ -59,15 +60,15 @@ kernel void int6pack_mm(
uchar b4 = B_ptr[3 * (k / 4) + 4];
uchar b5 = B_ptr[3 * (k / 4) + 5];

uchar w_val0 = ((b0 & 3) << 4) | (b1 & 15);
uchar w_val1 = ((b0 & 12) << 2) | ((b1 & 240) >> 4);
uchar w_val2 = ((b0 & 48)) | (b2 & 15);
uchar w_val3 = ((b0 & 192) >> 2) | ((b2 & 240) >> 4);
uchar w_val0 = (b0 & 0x3f);
uchar w_val1 = ((b0 & 0xc0) >> 6) | ((b1 & 0x0f) << 2);
uchar w_val2 = ((b1 & 0xf0) >> 4) | ((b2 & 0x03) << 4);
uchar w_val3 = (b2 & 0xfc) >> 2;

uchar w_val4 = ((b3 & 3) << 4) | (b4 & 15);
uchar w_val5 = ((b3 & 12) << 2) | ((b4 & 240) >> 4);
uchar w_val6 = ((b3 & 48)) | (b5 & 15);
uchar w_val7 = ((b3 & 192) >> 2) | ((b5 & 240) >> 4);
uchar w_val4 = (b3 & 0x3f);
uchar w_val5 = ((b3 & 0xc0) >> 6) | ((b4 & 0x0f) << 2);
uchar w_val6 = ((b4 & 0xf0) >> 4) | ((b5 & 0x03) << 4);
uchar w_val7 = (b5 & 0xfc) >> 2;

rc += a_val0 * (scale * float(w_val0) + zero);
rc += a_val1 * (scale * float(w_val1) + zero);
Expand Down
26 changes: 13 additions & 13 deletions torchao/experimental/kernels/mps/metal/int7mm.metal
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ using namespace metal;
*
* @param[A] M x K input tensor of floating point dtype (Float, Half, BFloat16)
* @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (7 * K / 8)
* @param[scales] 2D tensor containg the scales for each group. Expected shape is #groups x N
* @param[zeros] 2D tensor containg the zero points for each group. Expected shape is #groups x N
 * @param[scales] 2D tensor containing the scales for each group. Expected shape is N x #groups
 * @param[zeros] 2D tensor containing the zero points for each group. Expected shape is N x #groups
* @param[outputData] M x N output tensor of floating point dtype (same as input)
* @param[sizes] The sizes involved in the order: M, K, N
*
Expand All @@ -29,6 +29,7 @@ kernel void int7pack_mm(
uint2 thread_index [[thread_position_in_grid]]) {
const uint K = sizes.y;
const uint N = sizes.z;
const uint num_groups = (K + groupSize - 1) / groupSize;
const uint m = thread_index.y; // 0..M-1
const uint n = thread_index.x; // 0..N-1
const uint32_t k_block = (K + groupSize - 1) / groupSize;
Expand All @@ -38,8 +39,8 @@ kernel void int7pack_mm(
float rc = 0.0;
uint k = 0;
for (uint32_t kb = 0; kb < k_block ; kb ++) {
const float scale = float(scales[kb * N + n]);
const float zero = float(zeros[kb * N + n]);
const float scale = float(scales[n * num_groups + kb]);
const float zero = float(zeros[n * num_groups + kb]);
for(uint idx = 0; idx < groupSize && k < K; idx+=8, k+=8) {
const auto a_val0 = float(A_ptr[k + 0]);
const auto a_val1 = float(A_ptr[k + 1]);
Expand All @@ -58,15 +59,14 @@ kernel void int7pack_mm(
uchar b5 = B_ptr[7 * (k / 8) + 5];
uchar b6 = B_ptr[7 * (k / 8) + 6];

uchar w_val0 = b0 & 127;
uchar w_val1 = b1 & 127;
uchar w_val2 = b2 & 127;
uchar w_val3 = b3 & 127;
uchar w_val4 = b4 & 127;
uchar w_val5 = b5 & 127;
uchar w_val6 = b6 & 127;
uchar w_val7 = ((b0 & 128) >> 7) | ((b1 & 128) >> 6) | ((b2 & 128) >> 5) | ((b3 & 128) >> 4)
| ((b4 & 128) >> 3) | ((b5 & 128) >> 2) | ((b6 & 128) >> 1);
uchar w_val0 = (b0 & 0x7f);
uchar w_val1 = (b0 >> 7) | ((b1 & 0x3f) << 1);
uchar w_val2 = (b1 >> 6) | ((b2 & 0x1f) << 2);
uchar w_val3 = (b2 >> 5) | ((b3 & 0x0f) << 3);
uchar w_val4 = (b3 >> 4) | ((b4 & 0x07) << 4);
uchar w_val5 = (b4 >> 3) | ((b5 & 0x03) << 5);
uchar w_val6 = (b5 >> 2) | ((b6 & 0x01) << 6);
uchar w_val7 = (b6 >> 1);

rc += a_val0 * (scale * float(w_val0) + zero);
rc += a_val1 * (scale * float(w_val1) + zero);
Expand Down
Loading
Loading