Skip to content

Commit 71bef66

Browse files
committed
cuda : graceful fallback for Mamba-1 models with weird embd size
1 parent 73de1fd commit 71bef66

File tree: 2 files changed (+8 −7 lines)

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3329,12 +3329,14 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                     return op->src[0]->ne[0] == 128 && op->src[0]->ne[1] % 16 == 0;
                 } else {
                     // Mamba
-                    // (kernel only supports d_state == 16, n_group == 1, d_head == 1)
-                    return op->src[0]->ne[0] == 16 && op->src[4]->ne[1] == 1 && op->src[0]->ne[1] == 1;
+                    // (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1)
+                    return op->src[0]->ne[0] == 16 && op->src[0]->ne[1] == 1 && op->src[0]->ne[2] % 128 == 0 && op->src[4]->ne[1] == 1;
                 }
             }
-        case GGML_OP_SSM_CONV:
-            return true;
+        case GGML_OP_SSM_CONV: {
+                // assumes d_inner % threads == 0
+                return op->src[0]->ne[1] % 128 == 0;
+            }
         case GGML_OP_CONT:
             return op->src[0]->type != GGML_TYPE_BF16;
         case GGML_OP_DIAG_MASK_INF:

ggml/src/ggml-cuda/ssm-scan.cu

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
     const int threads = 128;
     // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
     if (src3_nb1 == sizeof(float)) {
-        // Mamba2
+        // Mamba-2
         if (d_state == 128) {
             GGML_ASSERT(d_state % threads == 0);
             // NOTE: can be any power of two between 4 and 64
@@ -219,8 +219,7 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
            GGML_ABORT("doesn't support d_state!=128.");
        }
    } else {
-        // Mamba1
-        // todo: consider n_head cannot be divided, does this situation exist?
+        // Mamba-1
        GGML_ASSERT(n_head % threads == 0);
        GGML_ASSERT(head_dim == 1);
        GGML_ASSERT(n_group == 1);

0 commit comments

Comments
 (0)