
Commit 25a0b90

cuda : try cublasGemmStridedBatchedEx

1 parent: d415669

File tree

1 file changed (+15 −2)

ggml-cuda.cu

Lines changed: 15 additions & 2 deletions
@@ -7134,8 +7134,21 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
         }
     }
 #else
-    // use cublasGemmBatchedEx
-    {
+    if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
+        // there is no broadcast and src0, src1 are contiguous across dims 2, 3
+        // use cublasGemmStridedBatchedEx
+        CUBLAS_CHECK(
+        cublasGemmStridedBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+                ne01, ne11, ne10,
+                &alpha_f16, (char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half),  ne02*src0->nb[2],   // strideA
+                            (char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), ne12*src1->nb[2]/2, // strideB
+                &beta_f16,  (char *)     dst_f16, CUDA_R_16F, ne01,               ne12* dst->nb[2]/2, // strideC
+                ne13,
+                CUBLAS_COMPUTE_16F,
+                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+        printf("cublasGemmStridedBatchedEx\n");
+    } else {
+        // use cublasGemmBatchedEx
         const int ne23 = ne12*ne13;

         // TODO: avoid this alloc
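For context: cublasGemmStridedBatchedEx computes a whole batch of GEMMs from a single base pointer per operand plus a fixed per-matrix stride, which is why the new branch is only taken when there is no broadcast and src0/src1 are contiguous across dims 2 and 3. The sketch below is a minimal, self-contained FP16 strided-batched GEMM call; it is not taken from the commit, and the sizes, buffer names (dA, dB, dC) and zero-initialization are illustrative assumptions.

    // Minimal sketch of a cublasGemmStridedBatchedEx call (assumes CUDA 11+ / cuBLAS).
    // Error checking is omitted for brevity.
    #include <cublas_v2.h>
    #include <cuda_fp16.h>
    #include <cuda_runtime.h>

    int main() {
        const int m = 64, n = 32, k = 128;   // per-matrix GEMM dimensions (illustrative)
        const int batch = 8;                 // number of independent GEMMs in the batch

        // One contiguous buffer per operand: matrix i starts at base + i*stride (in elements),
        // which is exactly the layout the *strided* batched variant expects.
        const long long strideA = (long long) m * k;
        const long long strideB = (long long) k * n;
        const long long strideC = (long long) m * n;

        half *dA, *dB, *dC;
        cudaMalloc(&dA, sizeof(half) * strideA * batch);
        cudaMalloc(&dB, sizeof(half) * strideB * batch);
        cudaMalloc(&dC, sizeof(half) * strideC * batch);
        cudaMemset(dA, 0, sizeof(half) * strideA * batch);
        cudaMemset(dB, 0, sizeof(half) * strideB * batch);
        cudaMemset(dC, 0, sizeof(half) * strideC * batch);

        cublasHandle_t handle;
        cublasCreate(&handle);

        // With CUBLAS_COMPUTE_16F, alpha and beta must be half-precision scalars.
        const half alpha = __float2half(1.0f);
        const half beta  = __float2half(0.0f);

        // One call performs `batch` GEMMs: C_i = alpha * A_i * B_i + beta * C_i,
        // column-major, with lda/ldb/ldc as the leading dimensions.
        cublasGemmStridedBatchedEx(handle,
                CUBLAS_OP_N, CUBLAS_OP_N,
                m, n, k,
                &alpha,
                dA, CUDA_R_16F, m, strideA,
                dB, CUDA_R_16F, k, strideB,
                &beta,
                dC, CUDA_R_16F, m, strideC,
                batch,
                CUBLAS_COMPUTE_16F,
                CUBLAS_GEMM_DEFAULT_TENSOR_OP);

        cudaDeviceSynchronize();

        cublasDestroy(handle);
        cudaFree(dA); cudaFree(dB); cudaFree(dC);
        return 0;
    }

Compared with cublasGemmBatchedEx, the strided variant needs no per-matrix pointer arrays on the device, which is the allocation the existing "TODO: avoid this alloc" path still performs in the else branch.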
