Skip to content

Commit 7355f75

Browse files
committed
Revert "CUDA: fastdiv, launch bounds for mmvq + q8_1 quant (ggml-org#15802)"
This reverts commit 5143fa8.
1 parent 46dcac9 commit 7355f75

File tree

3 files changed

+77
-67
lines changed

3 files changed

+77
-67
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -575,8 +575,6 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
575575
//
576576
// n/d = (mulhi(n, mp) + n) >> L;
577577
static const uint3 init_fastdiv_values(uint32_t d) {
578-
GGML_ASSERT(d != 0);
579-
580578
// compute L = ceil(log2(d));
581579
uint32_t L = 0;
582580
while (L < 32 && (uint32_t{ 1 } << L) < d) {

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 67 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -141,10 +141,9 @@ template <ggml_type type, int ncols_dst>
141141
__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
142142
static __global__ void mul_mat_vec_q(
143143
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst,
144-
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
145-
const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
146-
const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
147-
const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst) {
144+
const int ncols_x, const int nchannels_y, const int stride_row_x, const int stride_col_y, const int stride_col_dst,
145+
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
146+
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
148147

149148
constexpr int qk = ggml_cuda_type_traits<type>::qk;
150149
constexpr int qi = ggml_cuda_type_traits<type>::qi;
@@ -162,12 +161,12 @@ static __global__ void mul_mat_vec_q(
162161
constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
163162

164163
// The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1.
165-
const uint32_t channel_dst = blockIdx.y;
166-
const uint32_t channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio);
167-
const uint32_t channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
168-
const uint32_t sample_dst = blockIdx.z;
169-
const uint32_t sample_x = fastdiv(sample_dst, sample_ratio);
170-
const uint32_t sample_y = sample_dst;
164+
const int channel_dst = blockIdx.y;
165+
const int channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : channel_dst / channel_ratio;
166+
const int channel_y = ncols_dst == 1 && ids ? channel_dst % nchannels_y : channel_dst;
167+
const int sample_dst = blockIdx.z;
168+
const int sample_x = sample_dst / sample_ratio;
169+
const int sample_y = sample_dst;
171170

172171
// partial sum for each thread
173172
float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}};
@@ -248,80 +247,95 @@ static void mul_mat_vec_q_switch_ncols_dst(
248247
GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
249248
GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE);
250249

251-
const uint3 nchannels_y_fd = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0);
252-
const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
253-
const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x);
250+
const int channel_ratio = nchannels_dst / nchannels_x;
251+
const int sample_ratio = nsamples_dst / nsamples_x;
254252

255253
const int device = ggml_cuda_get_device();
256254
const int warp_size = ggml_cuda_info().devices[device].warp_size;
257255
const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
258256

259257
GGML_ASSERT(!ids || ncols_dst == 1);
260258
switch (ncols_dst) {
261-
case 1: {
259+
case 1:
260+
{
262261
constexpr int c_ncols_dst = 1;
263262
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
264263
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
265-
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
266-
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
267-
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
268-
} break;
269-
case 2: {
264+
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
265+
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
266+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
267+
break;
268+
}
269+
case 2:
270+
{
270271
constexpr int c_ncols_dst = 2;
271272
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
272273
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
273-
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
274-
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
275-
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
276-
} break;
277-
case 3: {
274+
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
275+
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
276+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
277+
break;
278+
}
279+
case 3:
280+
{
278281
constexpr int c_ncols_dst = 3;
279282
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
280283
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
281-
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
282-
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
283-
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
284-
} break;
285-
case 4: {
284+
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
285+
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
286+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
287+
break;
288+
}
289+
case 4:
290+
{
286291
constexpr int c_ncols_dst = 4;
287292
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
288293
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
289-
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
290-
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
291-
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
292-
} break;
293-
case 5: {
294+
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
295+
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
296+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
297+
break;
298+
}
299+
case 5:
300+
{
294301
constexpr int c_ncols_dst = 5;
295302
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
296303
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
297-
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
298-
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
299-
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
300-
} break;
301-
case 6: {
304+
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
305+
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
306+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
307+
break;
308+
}
309+
case 6:
310+
{
302311
constexpr int c_ncols_dst = 6;
303312
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
304313
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
305-
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
306-
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
307-
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
308-
} break;
309-
case 7: {
314+
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
315+
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
316+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
317+
break;
318+
}
319+
case 7:
320+
{
310321
constexpr int c_ncols_dst = 7;
311322
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
312323
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
313-
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
314-
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
315-
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
316-
} break;
317-
case 8: {
324+
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
325+
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
326+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
327+
break;
328+
}
329+
case 8:
330+
{
318331
constexpr int c_ncols_dst = 8;
319332
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
320333
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
321-
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
322-
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
323-
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
324-
} break;
334+
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
335+
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
336+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
337+
break;
338+
}
325339
default:
326340
GGML_ABORT("fatal error");
327341
break;

ggml/src/ggml-cuda/quantize.cu

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,26 @@
11
#include "quantize.cuh"
22
#include <cstdint>
33

4-
__launch_bounds__(CUDA_QUANTIZE_BLOCK_SIZE, 1)
54
static __global__ void quantize_q8_1(
65
const float * __restrict__ x, void * __restrict__ vy,
76
const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
8-
const int64_t ne0, const uint32_t ne1, const uint3 ne2) {
7+
const int64_t ne0, const int ne1, const int ne2) {
98
const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
109

1110
if (i0 >= ne0) {
1211
return;
1312
}
1413

15-
const int64_t i3 = fastdiv(blockIdx.z, ne2);
16-
const int64_t i2 = blockIdx.z - i3*ne2.z;
1714
const int64_t i1 = blockIdx.y;
15+
const int64_t i2 = blockIdx.z % ne2;
16+
const int64_t i3 = blockIdx.z / ne2;
1817

1918
const int64_t & i00 = i0;
2019
const int64_t & i01 = i1;
2120
const int64_t & i02 = i2;
2221
const int64_t & i03 = i3;
2322

24-
const int64_t i_cont = ((i3*ne2.z + i2) * ne1 + i1) * ne0 + i0;
23+
const int64_t i_cont = ((i3*ne2 + i2) * ne1 + i1) * ne0 + i0;
2524

2625
block_q8_1 * y = (block_q8_1 *) vy;
2726

@@ -32,10 +31,10 @@ static __global__ void quantize_q8_1(
3231
float amax = fabsf(xi);
3332
float sum = xi;
3433

35-
amax = warp_reduce_max<QK8_1>(amax);
36-
sum = warp_reduce_sum<QK8_1>(sum);
34+
amax = warp_reduce_max(amax);
35+
sum = warp_reduce_sum(sum);
3736

38-
const float d = amax / 127.0f;
37+
const float d = amax / 127;
3938
const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
4039

4140
y[ib].qs[iqs] = q;
@@ -44,7 +43,8 @@ static __global__ void quantize_q8_1(
4443
return;
4544
}
4645

47-
y[ib].ds = make_half2(d, sum);
46+
reinterpret_cast<half&>(y[ib].ds.x) = d;
47+
reinterpret_cast<half&>(y[ib].ds.y) = sum;
4848
}
4949

5050
template <mmq_q8_1_ds_layout ds_layout>
@@ -152,12 +152,10 @@ void quantize_row_q8_1_cuda(
152152
GGML_ASSERT(!ids);
153153
GGML_ASSERT(ne0 % QK8_1 == 0);
154154

155-
const uint3 ne2_fastdiv = init_fastdiv_values(ne2);
156-
157155
const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
158156
const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
159157
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
160-
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2_fastdiv);
158+
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
161159
GGML_UNUSED(type_src0);
162160
}
163161

0 commit comments

Comments
 (0)