Skip to content

Commit 31f229c

Browse files
larger x tiles
1 parent 99291ce commit 31f229c

File tree

2 files changed: +13 −9 lines changed

Makefile

+1-1
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ ifdef LLAMA_CUBLAS
169169
LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
170170
OBJS += ggml-cuda.o
171171
NVCC = nvcc
172-
NVCCFLAGS = --forward-unknown-to-host-compiler
172+
NVCCFLAGS = --forward-unknown-to-host-compiler -use_fast_math
173173
ifdef CUDA_DOCKER_ARCH
174174
NVCCFLAGS += -Wno-deprecated-gpu-targets -arch=$(CUDA_DOCKER_ARCH)
175175
else

ggml-cuda.cu

+12-8
Original file line numberDiff line numberDiff line change
@@ -1662,24 +1662,24 @@ static __global__ void mul_mat_q(
16621662
const int tid_x = threadIdx.x;
16631663
const int tid_y = threadIdx.y;
16641664

1665-
const int row_dst_0 = blockIdx.x*WARP_SIZE;
1665+
const int row_dst_0 = 2*blockIdx.x*WARP_SIZE;
16661666
const int & row_x_0 = row_dst_0;
16671667
const int row_dst = row_dst_0 + tid_x;
16681668

16691669
const int col_dst_0 = blockIdx.y*WARP_SIZE;
16701670
const int & col_y_0 = col_dst_0;
16711671

1672-
__shared__ int tile_x_qs[WARP_SIZE][WARP_SIZE + 1];
1673-
__shared__ half tile_x_d[WARP_SIZE][WARP_SIZE/QI4_0];
1672+
__shared__ int tile_x_qs[2*WARP_SIZE][WARP_SIZE + 1];
1673+
__shared__ half tile_x_d[2*WARP_SIZE][WARP_SIZE/QI4_0];
16741674
__shared__ int tile_y_qs[WARP_SIZE][2*WARP_SIZE];
16751675
__shared__ half2 tile_y_ds[WARP_SIZE][2*WARP_SIZE/QI8_1];
1676-
float sum[4] = {0.0f};
1676+
float sum[2][4] = {0.0f};
16771677

16781678
for (int ib0 = 0; ib0 < blocks_per_row; ib0 += blocks_per_warp) {
16791679
const int ibx = tid_x / QI4_0;
16801680
const int iqsx = sizeof(int) * (tid_x % QI4_0);
16811681

1682-
for (int j = 0; j < WARP_SIZE; j += 8) {
1682+
for (int j = 0; j < 2*WARP_SIZE; j += 8) {
16831683
const block_q4_0 * __restrict__ bx = &x[(row_x_0 + j + tid_y)*blocks_per_row + ib0 + ibx];
16841684
memcpy(&tile_x_qs[j + tid_y][tid_x], &bx->qs[iqsx], sizeof(int));
16851685
tile_x_d[j + tid_y][ibx] = bx->d;
@@ -1706,9 +1706,12 @@ static __global__ void mul_mat_q(
17061706
for (int k = 0; k < WARP_SIZE; ++k) {
17071707
const int iqsy = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
17081708
for (int j = 0; j < WARP_SIZE; j += 8) {
1709-
sum[j/8] += vec_dot_q4_0_q8_1_impl(
1709+
sum[0][j/8] += vec_dot_q4_0_q8_1_impl(
17101710
tile_x_qs[tid_x][k], tile_y_qs[tid_y + j][iqsy + 0], tile_y_qs[tid_y + j][iqsy + (QI8_1/2)],
17111711
tile_x_d[tid_x][k / QI4_0], tile_y_ds[tid_y + j][2 * k / QI8_1]);
1712+
sum[1][j/8] += vec_dot_q4_0_q8_1_impl(
1713+
tile_x_qs[tid_x + WARP_SIZE][k], tile_y_qs[tid_y + j][iqsy + 0], tile_y_qs[tid_y + j][iqsy + (QI8_1/2)],
1714+
tile_x_d[tid_x + WARP_SIZE][k / QI4_0], tile_y_ds[tid_y + j][2 * k / QI8_1]);
17121715
}
17131716
}
17141717

@@ -1727,7 +1730,8 @@ static __global__ void mul_mat_q(
17271730
return;
17281731
}
17291732

1730-
dst[col_dst*nrows_dst + row_dst] = sum[j/8];
1733+
dst[col_dst*nrows_dst + row_dst] = sum[0][j/8];
1734+
dst[col_dst*nrows_dst + row_dst + WARP_SIZE] = sum[1][j/8];
17311735
}
17321736
}
17331737

@@ -2417,7 +2421,7 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
24172421
}
24182422

24192423
static void ggml_mul_mat_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){
2420-
const int block_num_x = (nrows_x + WARP_SIZE - 1) / WARP_SIZE;
2424+
const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
24212425
const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
24222426
const dim3 block_nums(block_num_x, block_num_y, 1);
24232427
const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);

0 commit comments

Comments (0)