Skip to content

Commit 3d297c1

Browse files
committed
cuda : add cublasGemmStridedBatchedEx for non-broadcasted cases
1 parent d415669 commit 3d297c1

File tree

1 file changed

+18
-5
lines changed

1 file changed

+18
-5
lines changed

ggml-cuda.cu

Lines changed: 18 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@
3030
#define cublasCreate hipblasCreate
3131
#define cublasGemmEx hipblasGemmEx
3232
#define cublasGemmBatchedEx hipblasGemmBatchedEx
33+
#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
3334
#define cublasHandle_t hipblasHandle_t
3435
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
3536
#define cublasSetStream hipblasSetStream
@@ -7125,17 +7126,29 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
71257126
CUBLAS_CHECK(
71267127
cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
71277128
ne01, ne11, ne10,
7128-
&alpha_f16, (char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half),
7129-
(char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
7130-
&beta_f16, (char *) dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2, CUDA_R_16F, ne01,
7129+
&alpha_f16, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half),
7130+
(const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
7131+
&beta_f16, ( char *) dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2, CUDA_R_16F, ne01,
71317132
CUBLAS_COMPUTE_16F,
71327133
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
71337134
}
71347135
}
71357136
}
71367137
#else
7137-
// use cublasGemmBatchedEx
7138-
{
7138+
if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
7139+
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
7140+
// use cublasGemmStridedBatchedEx
7141+
CUBLAS_CHECK(
7142+
cublasGemmStridedBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
7143+
ne01, ne11, ne10,
7144+
&alpha_f16, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA
7145+
(const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
7146+
&beta_f16, ( char *) dst_f16, CUDA_R_16F, ne01, dst->nb[2]/sizeof(float), // strideC
7147+
ne12*ne13,
7148+
CUBLAS_COMPUTE_16F,
7149+
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
7150+
} else {
7151+
// use cublasGemmBatchedEx
71397152
const int ne23 = ne12*ne13;
71407153

71417154
// TODO: avoid this alloc

0 commit comments

Comments (0)