@@ -402,120 +402,86 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask || self_kq_mask_swa) {
-        // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
-        if (cparams.causal_attn) {
-            const int64_t n_kv         = kv_self->n;
-            const int64_t n_tokens     = ubatch->n_tokens;
-            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-            const int64_t n_seqs       = ubatch->n_seqs;
-
-            float * data     = nullptr;
-            float * data_swa = nullptr;
-
-            if (self_kq_mask) {
-                GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
-                data = (float *) self_kq_mask->data;
-            }
+        const int64_t n_kv         = kv_self->n;
+        const int64_t n_tokens     = ubatch->n_tokens;
+        const int64_t n_seq_tokens = ubatch->n_seq_tokens;
+        const int64_t n_seqs       = ubatch->n_seqs;
 
-            if (self_kq_mask_swa) {
-                GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
-                data_swa = (float *) self_kq_mask_swa->data;
-            }
+        float * data     = nullptr;
+        float * data_swa = nullptr;
 
-            // For causal attention, use only the previous KV cells
-            // of the correct sequence for each token of the ubatch.
-            // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
-            for (int h = 0; h < 1; ++h) {
-                for (int s = 0; s < n_seqs; ++s) {
-                    const llama_seq_id seq_id = ubatch->seq_id[s][0];
+        if (self_kq_mask) {
+            GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
+            data = (float *) self_kq_mask->data;
+        }
 
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
+        if (self_kq_mask_swa) {
+            GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
+            data_swa = (float *) self_kq_mask_swa->data;
+        }
 
-                        for (int i = 0; i < n_kv; ++i) {
-                            float f;
-                            if (!kv_self->cells[i].has_seq_id(seq_id) || kv_self->cells[i].pos > pos) {
-                                f = -INFINITY;
+        // Use only the previous KV cells of the correct sequence for each token of the ubatch.
+        // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
+        // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
+        //   Causal mask:
+        //      xxx-------
+        //      xxxx------
+        //      xxxxx-----
+        //   Non-causal mask:
+        //      xxxxx-----
+        //      xxxxx-----
+        //      xxxxx-----
+        // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
+        for (int h = 0; h < 1; ++h) {
+            for (int s = 0; s < n_seqs; ++s) {
+                const llama_seq_id seq_id = ubatch->seq_id[s][0];
+
+                for (int j = 0; j < n_seq_tokens; ++j) {
+                    const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
+                    for (int i = 0; i < n_kv; ++i) {
+                        float f;
+                        // mask the token if:
+                        if (!kv_self->cells[i].has_seq_id(seq_id) // not the correct sequence
+                            || (cparams.causal_attn && kv_self->cells[i].pos > pos) // for causal, mask future tokens
+                        ) {
+                            f = -INFINITY;
+                        } else {
+                            if (hparams.use_alibi) {
+                                f = -std::abs(kv_self->cells[i].pos - pos);
                             } else {
-                                if (hparams.use_alibi) {
-                                    f = -std::abs(kv_self->cells[i].pos - pos);
-                                } else {
-                                    f = 0.0f;
-                                }
-                            }
-
-                            if (data) {
-                                data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
+                                f = 0.0f;
                             }
+                        }
 
-                            // may need to cut off old tokens for sliding window
-                            if (data_swa) {
-                                if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
-                                    f = -INFINITY;
-                                }
-                                data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
-                            }
+                        if (data) {
+                            data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
                         }
-                    }
-                }
 
-                if (data) {
-                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                        for (int j = 0; j < n_kv; ++j) {
-                            data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                        // may need to cut off old tokens for sliding window
+                        if (data_swa) {
+                            if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
+                                f = -INFINITY;
+                            }
+                            data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
                         }
                     }
                 }
+            }
 
-                if (data_swa) {
-                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                        for (int j = 0; j < n_kv; ++j) {
-                            data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                        }
+            // mask padded tokens
+            if (data) {
+                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                    for (int j = 0; j < n_kv; ++j) {
+                        data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
                     }
                 }
             }
-        } else {
-            const int64_t n_tokens     = ubatch->n_tokens;
-            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-            const int64_t n_seqs       = ubatch->n_seqs;
-            // when using kv cache, the mask needs to match the kv cache size
-            const int64_t n_stride = n_tokens;
 
-            GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
-
-            float * data = (float *) self_kq_mask->data;
-
-            for (int h = 0; h < 1; ++h) {
-                for (int s1 = 0; s1 < n_seqs; ++s1) {
-                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];
-
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const int32_t tj = s1*n_seq_tokens + j;
-
-                        for (int s0 = 0; s0 < n_seqs; ++s0) {
-                            for (int i = 0; i < n_seq_tokens; ++i) {
-                                const int32_t ti = s0*n_seq_tokens + i;
-                                float f = -INFINITY;
-
-                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
-                                    if (ubatch->seq_id[s0][s] == seq_id) {
-                                        if (hparams.use_alibi) {
-                                            f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
-                                        } else {
-                                            f = 0.0f;
-                                        }
-                                        break;
-                                    }
-                                }
-
-                                data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
-                            }
-                        }
-
-                        for (int i = n_tokens; i < n_stride; ++i) {
-                            data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
-                        }
+            // mask padded tokens
+            if (data_swa) {
+                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                    for (int j = 0; j < n_kv; ++j) {
+                        data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
                     }
                 }
             }
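For readers tracing the new unified mask logic, here is a small standalone sketch (not part of the change; `Cell`, `batch_pos`, and `seq_id` below are hypothetical simplifications of `llama_kv_cell`, `ubatch->pos`, and `ubatch->seq_id`). It mirrors the single loop above, where only the `cparams.causal_attn` check differs between the causal and non-causal cases, and reproduces the example from the new comment block: a cache of 10 cells, 2 old tokens plus 3 batch tokens already placed in cells (as they are when `set_input()` runs in llama.cpp), printed once with causal attention and once without. ALiBi and the SWA cutoff are omitted for brevity.

```cpp
// Minimal sketch with simplified stand-in types; not the llama.cpp API.
#include <cmath>
#include <cstdio>
#include <vector>

struct Cell {
    int pos;     // position stored in this KV cache cell (-1 = empty)
    int seq_id;  // sequence the cell belongs to (-1 = empty)
};

int main() {
    const int n_kv     = 10; // cache size -> mask width
    const int n_tokens = 3;  // ubatch size -> mask height

    // 5 of the 10 cells are populated: 2 old tokens plus the 3 batch tokens, all of sequence 0
    std::vector<Cell> cells(n_kv, Cell{-1, -1});
    for (int i = 0; i < 5; ++i) {
        cells[i] = Cell{i, 0};
    }

    const int batch_pos[n_tokens] = {2, 3, 4}; // positions of the ubatch tokens
    const int seq_id = 0;                      // sequence of the ubatch tokens

    for (bool causal_attn : {true, false}) {
        std::vector<float> mask(n_tokens * n_kv, -INFINITY);

        for (int j = 0; j < n_tokens; ++j) {
            const int pos = batch_pos[j];
            for (int i = 0; i < n_kv; ++i) {
                // same condition as the new code: mask cells of other sequences and,
                // only when causal_attn is set, cells that lie in the future of this token
                if (cells[i].seq_id != seq_id || (causal_attn && cells[i].pos > pos)) {
                    continue; // stays -INFINITY
                }
                mask[j*n_kv + i] = 0.0f; // visible (ALiBi bias and SWA cutoff omitted)
            }
        }

        std::printf("%s mask:\n", causal_attn ? "Causal" : "Non-causal");
        for (int j = 0; j < n_tokens; ++j) {
            for (int i = 0; i < n_kv; ++i) {
                std::putchar(mask[j*n_kv + i] == 0.0f ? 'x' : '-');
            }
            std::putchar('\n');
        }
    }

    return 0;
}
```

Running it prints the two 3x10 masks shown in the comment block of the diff ("xxx-------" growing by one cell per row for the causal case, three identical "xxxxx-----" rows for the non-causal case).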