Commit d19838e

CUDA: FA support for Deepseek (Ampere or newer)
Parent: 93c4e23

30 files changed, +786 -512 lines

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 13 additions & 13 deletions
@@ -516,7 +516,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
         nullptr;
 }
 
-template<int D, int ncols1, int ncols2, int KQ_stride> // D == head size
+template<int D, int ncols1, int ncols2> // D == head size
 __launch_bounds__(D, 1)
 static __global__ void flash_attn_stream_k_fixup(
         float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
@@ -665,13 +665,13 @@ static void on_no_fattn_vec_case(const int D) {
         fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
         GGML_ABORT("fatal error");
     } else {
-        fprintf(stderr, "Unsupported KV type combination for head_size 256.\n");
+        fprintf(stderr, "Unsupported KV type combination for head_size %d.\n", D);
         fprintf(stderr, "Only f16 is supported.\n");
         GGML_ABORT("fatal error");
     }
 }
 
-template <int D, int ncols1, int ncols2, int KQ_stride>
+template <int DV, int ncols1, int ncols2>
 void launch_fattn(
         ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
         const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
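
Note on the hunk above: renaming the launch_fattn template parameter from D to DV suggests the launcher is now sized by the V (output) head size rather than a single shared head size, presumably because DeepSeek-style MLA attention uses different K/Q and V head sizes. A minimal illustrative sketch, using assumed head-size values that are not taken from this diff:

// Illustrative constants only; the concrete values are assumptions about DeepSeek-style MLA heads.
constexpr int DKQ = 576; // assumed K/Q head size
constexpr int DV  = 512; // assumed V head size, i.e. the number of output values per attention head
static_assert(DKQ != DV, "a single head-size parameter D is ambiguous once K/Q and V sizes differ");

Consistent with this, the later hunks size dst_tmp_meta, the fixup kernel, and the combine kernel with DV, since all of them operate on output rows.
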
@@ -691,7 +691,7 @@ void launch_fattn(
 
     GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
     GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
-        "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
+                "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
 
     GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
 
@@ -752,10 +752,13 @@ void launch_fattn(
     const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
 
     const dim3 block_dim(warp_size, nwarps, 1);
+    int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
+    CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
+
     dim3 blocks_num;
     if (stream_k) {
         // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
-        const int max_blocks = 2*nsm;
+        const int max_blocks = max_blocks_per_sm*nsm;
         const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
         const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
 
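The hunk above hoists the occupancy query in front of the stream-k branch so that the stream-k grid can be sized as one full wave of max_blocks_per_sm*nsm blocks instead of the previous hard-coded 2*nsm. A standalone sketch of the same sizing arithmetic, with a placeholder kernel and a made-up tile count (not the library's code):

// Sketch only: dummy_kernel, the block shape, and ntiles_total are placeholders.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void dummy_kernel() {}

int main() {
    int device = 0, nsm = 0;
    cudaGetDevice(&device);
    cudaDeviceGetAttribute(&nsm, cudaDevAttrMultiProcessorCount, device); // number of SMs

    const dim3   block_dim(32, 4, 1); // warp_size x nwarps
    const size_t nbytes_shared = 0;   // dynamic shared memory per block
    const int    ntiles_total  = 300; // made-up number of work tiles

    int max_blocks_per_sm = 1; // max. number of active blocks limited by occupancy
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &max_blocks_per_sm, dummy_kernel, block_dim.x*block_dim.y*block_dim.z, nbytes_shared);

    // Same arithmetic as in the diff: one wave is nsm * max_blocks_per_sm blocks.
    const int max_blocks               = max_blocks_per_sm*nsm;
    const int tiles_nwaves             = (ntiles_total + max_blocks - 1) / max_blocks;
    const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);

    printf("blocks/SM = %d, max_blocks = %d, waves = %d, efficiency = %d%%\n",
           max_blocks_per_sm, max_blocks, tiles_nwaves, tiles_efficiency_percent);
    return 0;
}

For example, with nsm = 108 and max_blocks_per_sm = 2 this gives max_blocks = 216, so ntiles_total = 300 needs tiles_nwaves = 2 waves at a tile efficiency of about 69%.
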
@@ -767,14 +770,11 @@
         blocks_num.y = 1;
         blocks_num.z = 1;
 
-        dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
+        dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
     } else {
         GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
         const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
 
-        int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
-        CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
-
         // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
         parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
 
@@ -851,19 +851,19 @@
 
     if (stream_k) {
         if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
-            const dim3 block_dim_combine(D, 1, 1);
+            const dim3 block_dim_combine(DV, 1, 1);
             const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
 
-            flash_attn_stream_k_fixup<D, ncols1, ncols2, KQ_stride>
+            flash_attn_stream_k_fixup<DV, ncols1, ncols2>
                 <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
                 ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
         }
     } else if (parallel_blocks > 1) {
-        const dim3 block_dim_combine(D, 1, 1);
+        const dim3 block_dim_combine(DV, 1, 1);
         const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
         const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
 
-        flash_attn_combine_results<D>
+        flash_attn_combine_results<DV>
             <<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>>
             (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
     }
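
In the hunk above the fixup and combine kernels are launched with DV threads per block, i.e. one thread per element of the V head dimension, and the combine launch reserves parallel_blocks*sizeof(float2) of dynamic shared memory, which matches one {max logit, sum of exponents} pair per partial result. Below is a simplified, self-contained sketch of such a combination step; it is not the actual flash_attn_combine_results kernel, it reads the per-block float2 pairs straight from global memory, and the buffer layouts are assumptions:

// Sketch of merging partial attention outputs computed over disjoint KV slices.
#include <cuda_runtime.h>

template <int DV>
__global__ void combine_partials_sketch(
        const float  * __restrict__ partial, // [nrows, nblocks, DV] unnormalized partial outputs
        const float2 * __restrict__ meta,    // [nrows, nblocks] {max logit, sum of exponents}
        float        * __restrict__ dst,     // [nrows, DV] final output
        const int nblocks) {
    const int row = blockIdx.x;  // one thread block per output row
    const int d   = threadIdx.x; // one thread per element of the V head dimension (DV threads)

    // Global maximum of the per-block logits, needed for numerically stable rescaling.
    float kqmax = meta[row*nblocks].x;
    for (int b = 1; b < nblocks; ++b) {
        kqmax = fmaxf(kqmax, meta[row*nblocks + b].x);
    }

    // Rescale each partial result by exp(local_max - global_max) and accumulate.
    float numer = 0.0f;
    float denom = 0.0f;
    for (int b = 0; b < nblocks; ++b) {
        const float2 m     = meta[row*nblocks + b];
        const float  scale = expf(m.x - kqmax);
        numer += scale * partial[(row*nblocks + b)*DV + d];
        denom += scale * m.y;
    }
    dst[row*DV + d] = numer / denom;
}

A launch mirroring the diff's shape would be combine_partials_sketch<512><<<nrows, 512>>>(...), i.e. a grid of one block per query row and DV threads per block.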
