 #define cublasCreate hipblasCreate
 #define cublasGemmEx hipblasGemmEx
 #define cublasGemmBatchedEx hipblasGemmBatchedEx
+#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
 #define cublasHandle_t hipblasHandle_t
 #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
 #define cublasSetStream hipblasSetStream
@@ -7125,17 +7126,29 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
             CUBLAS_CHECK(
                 cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                     ne01, ne11, ne10,
-                    &alpha_f16, (char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3]  , CUDA_R_16F,   nb01/sizeof(half),
-                                (char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F,   nb11/sizeof(float),
-                    &beta_f16,  (char *)     dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2, CUDA_R_16F,   ne01,
+                    &alpha_f16, (const char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3]  , CUDA_R_16F,   nb01/sizeof(half),
+                                (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F,   nb11/sizeof(float),
+                    &beta_f16,  (      char *)     dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2, CUDA_R_16F,   ne01,
                     CUBLAS_COMPUTE_16F,
                     CUBLAS_GEMM_DEFAULT_TENSOR_OP));
         }
     }
 }
 #else
-    // use cublasGemmBatchedEx
-    {
+    if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
+        // there is no broadcast and src0, src1 are contiguous across dims 2, 3
+        // use cublasGemmStridedBatchedEx
+        CUBLAS_CHECK(
+        cublasGemmStridedBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+                ne01, ne11, ne10,
+                &alpha_f16, (const char *) src0_as_f16, CUDA_R_16F,   nb01/sizeof(half),  src0->nb[2]/sizeof(half),  // strideA
+                            (const char *) src1_as_f16, CUDA_R_16F,   nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
+                &beta_f16,  (      char *)     dst_f16, CUDA_R_16F,   ne01,               dst->nb[2]/sizeof(float),  // strideC
+                ne12*ne13,
+                CUBLAS_COMPUTE_16F,
+                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+    } else {
+        // use cublasGemmBatchedEx
         const int ne23 = ne12*ne13;

         // TODO: avoid this alloc
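The payoff of the new branch: cublasGemmStridedBatchedEx covers every matrix in the batch with a single call, using a fixed element stride between consecutive A/B/C matrices, whereas cublasGemmBatchedEx needs a per-matrix pointer array allocated and copied to the device first. The following stand-alone sketch only illustrates that call shape; it is not code from this commit, and the FP32 types, matrix sizes, and bare-bones error handling are illustrative assumptions.

// Minimal sketch of a strided-batched GEMM with cuBLAS (assumes CUDA 11+ for
// the cublasComputeType_t enum). Build with something like: nvcc demo.cu -lcublas
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

int main() {
    const int m = 4, n = 4, k = 4, batch = 8;         // illustrative sizes
    const long long strideA = (long long) m * k;      // elements between consecutive A matrices
    const long long strideB = (long long) k * n;
    const long long strideC = (long long) m * n;

    std::vector<float> hA(m*k*batch, 1.0f), hB(k*n*batch, 1.0f), hC(m*n*batch, 0.0f);

    float *dA, *dB, *dC;
    cudaMalloc(&dA, hA.size()*sizeof(float));
    cudaMalloc(&dB, hB.size()*sizeof(float));
    cudaMalloc(&dC, hC.size()*sizeof(float));
    cudaMemcpy(dA, hA.data(), hA.size()*sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(dB, hB.data(), hB.size()*sizeof(float), cudaMemcpyHostToDevice);

    cublasHandle_t handle;
    cublasCreate(&handle);

    const float alpha = 1.0f, beta = 0.0f;
    // One call multiplies all `batch` matrix pairs; the stride arguments give the
    // element offset from one matrix to the next, so no pointer array is needed.
    cublasGemmStridedBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
            m, n, k,
            &alpha, dA, CUDA_R_32F, m, strideA,
                    dB, CUDA_R_32F, k, strideB,
            &beta,  dC, CUDA_R_32F, m, strideC,
            batch,
            CUBLAS_COMPUTE_32F,
            CUBLAS_GEMM_DEFAULT);

    cudaMemcpy(hC.data(), dC, hC.size()*sizeof(float), cudaMemcpyDeviceToHost);
    printf("C[0](0,0) = %.1f (expected %.1f)\n", hC[0], (float) k);  // all-ones inputs -> k

    cublasDestroy(handle);
    cudaFree(dA); cudaFree(dB); cudaFree(dC);
    return 0;
}

In the kernel above, the strides come straight from the ggml byte strides (src0->nb[2]/sizeof(half), etc.), which is why the fast path is guarded by the no-broadcast and dims-2/3 contiguity check before falling back to cublasGemmBatchedEx.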