Commit d891942

CUDA: fix FlashAttention on Turing (#13415)
1 parent 7fef117 commit d891942

1 file changed: +1 −1 lines changed

ggml/src/ggml-cuda/fattn-mma-f16.cuh

Lines changed: 1 addition & 1 deletion
@@ -546,7 +546,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         const int i0_stop = i0_start + 2*c::nbatch_V2 < DV ? i0_start + 2*c::nbatch_V2 : DV;
         const int i0_diff = i0_stop - i0_start;
 
-        if (nstages == 1) {
+        if (nstages <= 1) {
             constexpr bool use_cp_async = nstages == 1;
             flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
                 (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
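
Why this one-character change matters: the kernel only uses asynchronous copies when it runs a single-stage software pipeline (use_cp_async is nstages == 1 in the diff), and cp.async is an Ampere (sm_80+) feature, so on Turing (sm_75) the kernel presumably runs with no pipeline stage at all. With the old condition `nstages == 1`, the nstages == 0 path skipped this block and the V tile was never loaded; `nstages <= 1` keeps the load and falls back to a synchronous copy, since use_cp_async evaluates to false. With two or more stages the tile is presumably prefetched elsewhere, which is why the condition is an upper bound. Below is a minimal CUDA sketch of that dispatch; the helper names (load_tile_sketch, iter_step_sketch) and the 0/1 stage convention are assumptions for illustration, not the llama.cpp implementation.

// Minimal sketch of a stage-dependent tile load; names are hypothetical.
#include <cuda_fp16.h>

template <bool use_cp_async>
__device__ __forceinline__ void load_tile_sketch(const half2 * __restrict__ src,
                                                 half2       * __restrict__ dst,
                                                 const int n) {
    if constexpr (use_cp_async) {
        // Ampere and newer (sm_80+): a real kernel would issue cp.async copies here
        // so the load overlaps with compute; omitted in this sketch.
    } else {
        // Turing (sm_75) and older: no cp.async, copy the tile synchronously.
        for (int i = threadIdx.x; i < n; i += blockDim.x) {
            dst[i] = src[i];
        }
    }
}

// nstages is a compile-time constant, mirroring the constexpr in the diff:
// 1 = single asynchronous pipeline stage, 0 = no pipeline (assumed Turing path).
template <int nstages>
__device__ void iter_step_sketch(const half2 * V_src, half2 * tile_V, const int n) {
    // Old condition "if (nstages == 1)" skipped this block when nstages == 0,
    // so the V tile was never loaded on that path.
    if (nstages <= 1) {
        constexpr bool use_cp_async = nstages == 1;
        load_tile_sketch<use_cp_async>(V_src, tile_V, n);
        __syncthreads(); // make the tile visible to all warps before it is consumed
    }
}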
