@@ -6868,43 +6868,52 @@ static void mul_mat_vec_q_cuda(
6868
6868
6869
6869
const int32_t config = ncols_y | (nwarps << 16 );
6870
6870
6871
- switch (config) {
6872
- case 0x00010001 :
6873
- mul_mat_vec_q<1 , 1 , qk, qi, block_q_t , vdr, vec_dot>
6874
- <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6875
- break ;
6876
- case 0x00010002 :
6877
- mul_mat_vec_q<1 , 2 , qk, qi, block_q_t , vdr, vec_dot>
6878
- <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6879
- break ;
6880
- case 0x00010003 :
6881
- mul_mat_vec_q<1 , 3 , qk, qi, block_q_t , vdr, vec_dot>
6882
- <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6883
- break ;
6884
- case 0x00010004 :
6885
- mul_mat_vec_q<1 , 4 , qk, qi, block_q_t , vdr, vec_dot>
6886
- <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6887
- break ;
6888
- case 0x00040001 :
6889
- mul_mat_vec_q<4 , 1 , qk, qi, block_q_t , vdr, vec_dot>
6890
- <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6891
- break ;
6892
- case 0x00040002 :
6893
- mul_mat_vec_q<4 , 2 , qk, qi, block_q_t , vdr, vec_dot>
6894
- <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6895
- break ;
6896
- case 0x00040003 :
6897
- mul_mat_vec_q<4 , 3 , qk, qi, block_q_t , vdr, vec_dot>
6898
- <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6899
- break ;
6900
- case 0x00040004 :
6901
- mul_mat_vec_q<4 , 4 , qk, qi, block_q_t , vdr, vec_dot>
6902
- <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6903
- break ;
6871
+ switch (nwarps) {
6872
+ case 1 : switch (ncols_y) {
6873
+ case 1 :
6874
+ mul_mat_vec_q<1 , 1 , qk, qi, block_q_t , vdr, vec_dot>
6875
+ <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6876
+ break ;
6877
+ case 2 :
6878
+ mul_mat_vec_q<1 , 2 , qk, qi, block_q_t , vdr, vec_dot>
6879
+ <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6880
+ break ;
6881
+ case 3 :
6882
+ mul_mat_vec_q<1 , 3 , qk, qi, block_q_t , vdr, vec_dot>
6883
+ <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6884
+ break ;
6885
+ case 4 :
6886
+ mul_mat_vec_q<1 , 4 , qk, qi, block_q_t , vdr, vec_dot>
6887
+ <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6888
+ break ;
6889
+ default :
6890
+ GGML_ASSERT (false );
6891
+ break ;
6892
+ } break ;
6893
+ case 4 : switch (ncols_y) {
6894
+ case 1 :
6895
+ mul_mat_vec_q<4 , 1 , qk, qi, block_q_t , vdr, vec_dot>
6896
+ <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6897
+ break ;
6898
+ case 2 :
6899
+ mul_mat_vec_q<4 , 2 , qk, qi, block_q_t , vdr, vec_dot>
6900
+ <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6901
+ break ;
6902
+ case 3 :
6903
+ mul_mat_vec_q<4 , 3 , qk, qi, block_q_t , vdr, vec_dot>
6904
+ <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6905
+ break ;
6906
+ case 4 :
6907
+ mul_mat_vec_q<4 , 4 , qk, qi, block_q_t , vdr, vec_dot>
6908
+ <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6909
+ break ;
6910
+ default :
6911
+ GGML_ASSERT (false );
6912
+ break ;
6913
+ } break ;
6914
+
6904
6915
default :
6905
6916
GGML_ASSERT (false );
6906
- // mul_mat_vec_q<0, qk, qi, block_q_t, vdr, vec_dot>
6907
- // <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6908
6917
break ;
6909
6918
}
6910
6919
}
0 commit comments