@@ -1662,24 +1662,24 @@ static __global__ void mul_mat_q(
     const int tid_x = threadIdx.x;
     const int tid_y = threadIdx.y;

-    const int row_dst_0 = blockIdx.x*WARP_SIZE;
+    const int row_dst_0 = 2*blockIdx.x*WARP_SIZE;
     const int & row_x_0 = row_dst_0;
     const int row_dst = row_dst_0 + tid_x;

     const int col_dst_0 = blockIdx.y*WARP_SIZE;
     const int & col_y_0 = col_dst_0;

-    __shared__ int   tile_x_qs[WARP_SIZE][WARP_SIZE + 1];
-    __shared__ half  tile_x_d[WARP_SIZE][WARP_SIZE/QI4_0];
+    __shared__ int   tile_x_qs[2*WARP_SIZE][WARP_SIZE + 1];
+    __shared__ half  tile_x_d[2*WARP_SIZE][WARP_SIZE/QI4_0];
     __shared__ int   tile_y_qs[WARP_SIZE][2*WARP_SIZE];
     __shared__ half2 tile_y_ds[WARP_SIZE][2*WARP_SIZE/QI8_1];
-    float sum[4] = {0.0f};
+    float sum[2][4] = {0.0f};

     for (int ib0 = 0; ib0 < blocks_per_row; ib0 += blocks_per_warp) {
         const int ibx  = tid_x / QI4_0;
         const int iqsx = sizeof(int) * (tid_x % QI4_0);

-        for (int j = 0; j < WARP_SIZE; j += 8) {
+        for (int j = 0; j < 2*WARP_SIZE; j += 8) {
             const block_q4_0 * __restrict__ bx = &x[(row_x_0 + j + tid_y)*blocks_per_row + ib0 + ibx];
             memcpy(&tile_x_qs[j + tid_y][tid_x], &bx->qs[iqsx], sizeof(int));
             tile_x_d[j + tid_y][ibx] = bx->d;
@@ -1706,9 +1706,12 @@ static __global__ void mul_mat_q(
         for (int k = 0; k < WARP_SIZE; ++k) {
             const int iqsy = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
             for (int j = 0; j < WARP_SIZE; j += 8) {
-                sum[j/8] += vec_dot_q4_0_q8_1_impl(
+                sum[0][j/8] += vec_dot_q4_0_q8_1_impl(
                     tile_x_qs[tid_x][k], tile_y_qs[tid_y + j][iqsy + 0], tile_y_qs[tid_y + j][iqsy + (QI8_1/2)],
                     tile_x_d[tid_x][k / QI4_0], tile_y_ds[tid_y + j][2*k / QI8_1]);
+                sum[1][j/8] += vec_dot_q4_0_q8_1_impl(
+                    tile_x_qs[tid_x + WARP_SIZE][k], tile_y_qs[tid_y + j][iqsy + 0], tile_y_qs[tid_y + j][iqsy + (QI8_1/2)],
+                    tile_x_d[tid_x + WARP_SIZE][k / QI4_0], tile_y_ds[tid_y + j][2*k / QI8_1]);
             }
         }

@@ -1727,7 +1730,8 @@ static __global__ void mul_mat_q(
             return;
         }

-        dst[col_dst*nrows_dst + row_dst] = sum[j/8];
+        dst[col_dst*nrows_dst + row_dst] = sum[0][j/8];
+        dst[col_dst*nrows_dst + row_dst + WARP_SIZE] = sum[1][j/8];
     }
 }

@@ -2417,7 +2421,7 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
 }

 static void ggml_mul_mat_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-    const int block_num_x = (nrows_x + WARP_SIZE - 1) / WARP_SIZE;
+    const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
     const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
     const dim3 block_nums(block_num_x, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
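
Taken together, these hunks make each thread block compute a 2*WARP_SIZE-wide strip of dst rows instead of a single WARP_SIZE strip: the x tiles in shared memory and the per-thread accumulator are doubled, each pass over the y tiles feeds two partial sums, and the launch code divides the row count by 2*WARP_SIZE. Below is a minimal host-side sketch, not part of the patch, that only illustrates the resulting launch geometry; it assumes WARP_SIZE = 32 as defined in ggml-cuda.cu and uses made-up example matrix sizes.

// Sketch of the launch geometry implied by this change (illustrative only).
// Assumptions: WARP_SIZE = 32; nrows_x and ncols_y are arbitrary example sizes.
#include <cstdio>

static const int WARP_SIZE = 32;

int main() {
    const int nrows_x = 4096; // example: rows of the quantized matrix == rows of dst
    const int ncols_y = 512;  // example: columns of the q8_1 activations == columns of dst

    const int block_num_x_old = (nrows_x + WARP_SIZE - 1) / WARP_SIZE;       // before: 128 blocks in x
    const int block_num_x_new = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE); // after:   64 blocks in x
    const int block_num_y     = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;       // unchanged: 16 blocks in y

    // Each block still loads the same y tiles into shared memory, but now reuses
    // them for two WARP_SIZE-wide row strips (sum[0][...] and sum[1][...]),
    // writing dst rows row_dst and row_dst + WARP_SIZE.
    printf("grid: (%d, %d) -> (%d, %d)\n",
           block_num_x_old, block_num_y, block_num_x_new, block_num_y);
    return 0;
}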