
Commit 0faf92e

ggml : require mask when using ALiBi
ggml-ci
1 parent 397b1f8 commit 0faf92e

File tree

2 files changed: +10 −1 lines changed


ggml.c

Lines changed: 9 additions & 0 deletions
@@ -5657,6 +5657,10 @@ static struct ggml_tensor * ggml_soft_max_impl(
         GGML_ASSERT(mask->ne[1] >= a->ne[1]);
     }
 
+    if (max_bias > 0.0f) {
+        GGML_ASSERT(mask);
+    }
+
     bool is_node = false;
 
     if (a->grad) {
@@ -6440,6 +6444,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
         float max_bias) {
     GGML_ASSERT(ggml_can_mul_mat(k, q));
     // TODO: check if vT can be multiplied by (k*qT)
+
     if (mask) {
         GGML_ASSERT(ggml_is_contiguous(mask));
         GGML_ASSERT(mask->ne[2] == 1);
@@ -6449,6 +6454,10 @@ struct ggml_tensor * ggml_flash_attn_ext(
         //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
     }
 
+    if (max_bias > 0.0f) {
+        GGML_ASSERT(mask);
+    }
+
     bool is_node = false;
 
     if (q->grad || k->grad || v->grad) {
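
For context (not part of the commit): ggml fuses the ALiBi bias into the soft-max through the mask tensor (roughly soft_max(a*scale + mask*slope), with the per-head slope derived from max_bias), so requesting a bias without a mask is now rejected when the op is built. A minimal call-site sketch of the rule, assuming the ggml_soft_max_ext signature (ctx, a, mask, scale, max_bias) at this revision; the wrapper name is made up for illustration:

    #include "ggml.h"

    // Sketch only: which (mask, max_bias) combinations the new assert allows.
    static struct ggml_tensor * attn_scores_soft_max(
            struct ggml_context * ctx,
            struct ggml_tensor  * kq,       // attention scores
            struct ggml_tensor  * kq_mask,  // may be NULL only when max_bias == 0.0f
            float                 kq_scale,
            float                 max_bias) {
        // ok : kq_mask != NULL, max_bias == 0.0f   (plain masked soft-max)
        // ok : kq_mask != NULL, max_bias  > 0.0f   (ALiBi bias applied via the mask)
        // ok : kq_mask == NULL, max_bias == 0.0f   (no mask, no bias)
        // bad: kq_mask == NULL, max_bias  > 0.0f   -> GGML_ASSERT(mask) fires
        return ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, max_bias);
    }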

tests/test-backend-ops.cpp

Lines changed: 1 addition & 1 deletion
@@ -2126,6 +2126,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
 #endif
     for (bool mask : {false, true}) {
         for (float max_bias : {0.0f, 8.0f}) {
+            if (!mask && max_bias > 0.0f) continue;
             for (float scale : {1.0f, 0.1f}) {
                 for (int64_t ne0 : {16, 1024}) {
                     for (int64_t ne1 : {16, 1024}) {
@@ -2139,7 +2140,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
 
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 0.0f));
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 8.0f));
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 8.0f));
 
     for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
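
The test change mirrors the new asserts: the (mask = false, max_bias = 8.0f) combination is no longer a valid way to build the soft-max op, so the loop now skips it and the matching standalone case is dropped. The same rule applies to ggml_flash_attn_ext; a hedged sketch, assuming the (ctx, q, k, v, mask, scale, max_bias) parameter order (the ggml.c hunk above only confirms max_bias as the final parameter):

    #include "ggml.h"

    // Sketch only: flash attention with ALiBi (max_bias > 0.0f) now requires a mask.
    static struct ggml_tensor * attn_flash(
            struct ggml_context * ctx,
            struct ggml_tensor  * q,
            struct ggml_tensor  * k,
            struct ggml_tensor  * v,
            struct ggml_tensor  * kq_mask,  // must be non-NULL when max_bias > 0.0f
            float                 kq_scale,
            float                 max_bias) {
        // Mirrors the new check in ggml_flash_attn_ext:
        //   if (max_bias > 0.0f) { GGML_ASSERT(mask); }
        return ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, max_bias);
    }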
