Commit fae826f

Fix failed assertions while running Falcon Mamba

Parent: 061e520

2 files changed: 5 additions, 5 deletions

ggml/src/ggml-cuda/norm.cu (2 additions, 2 deletions)

@@ -153,9 +153,9 @@ static void group_norm_f32_cuda(const float * x, float * dst, const int num_grou
 }
 
 static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
-    GGML_ASSERT(ncols % WARP_SIZE == 0);
+    GGML_ASSERT(ncols % WARP_SIZE == 0 || ncols < WARP_SIZE);
     if (ncols < 1024) {
-        const dim3 block_dims(WARP_SIZE, 1, 1);
+        const dim3 block_dims(min(ncols, WARP_SIZE), 1, 1);
         rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
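Context for the norm.cu change: rms_norm_f32_cuda launches one block per row, and the old assertion rejected any row width that was not a multiple of WARP_SIZE (32). FalconMamba RMS-norms the dt, B and C slices, whose rows can be narrower than a warp (Mamba's d_state is commonly 16), so the assertion fired. The patch admits ncols < WARP_SIZE and clamps the block width to min(ncols, WARP_SIZE). Below is a minimal, self-contained sketch of the same launch pattern; the kernel name rms_norm_rows is hypothetical, and it uses a plain shared-memory reduction rather than the warp shuffles the real rms_norm_f32 kernel uses:

#include <algorithm>
#include <cuda_runtime.h>

#define WARP_SIZE 32

// One block per row; each thread strides over the row's columns, so any
// block width from 1 to 1024 works. Partial sums meet in shared memory.
__global__ void rms_norm_rows(const float * x, float * dst, const int ncols, const float eps) {
    extern __shared__ float smem[];
    const float * row_x = x   + (size_t) blockIdx.x * ncols;
    float       * row_d = dst + (size_t) blockIdx.x * ncols;

    float partial = 0.0f;
    for (int col = threadIdx.x; col < ncols; col += blockDim.x) {
        partial += row_x[col] * row_x[col];
    }
    smem[threadIdx.x] = partial;
    __syncthreads();

    // serial combine by thread 0: fine for a sketch, since blockDim.x <= 1024
    if (threadIdx.x == 0) {
        float sum = 0.0f;
        for (int i = 1; i < blockDim.x; ++i) sum += smem[i];
        smem[0] += sum;
    }
    __syncthreads();

    // eps sits inside the rsqrt, matching ggml's RMS-norm definition
    const float scale = rsqrtf(smem[0] / ncols + eps);
    for (int col = threadIdx.x; col < ncols; col += blockDim.x) {
        row_d[col] = row_x[col] * scale;
    }
}

// Host launcher mirroring the patched logic: never launch a block wider
// than the row, so ncols == 16 (narrower than a warp) is now legal.
static void rms_norm_rows_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
    const int block_size = ncols < 1024 ? std::min(ncols, WARP_SIZE) : 1024;
    const dim3 block_dims(block_size, 1, 1);
    rms_norm_rows<<<nrows, block_dims, block_size * sizeof(float), stream>>>(x, dst, ncols, eps);
}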

src/llama.cpp (3 additions, 3 deletions)

@@ -9119,9 +9119,9 @@ static struct ggml_tensor * llm_build_mamba(
 
         // Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers
         if (ssm_dt_b_c_rms) {
-            dt = ggml_rms_norm(ctx, dt, norm_rms_eps);
-            B = ggml_rms_norm(ctx, B, norm_rms_eps);
-            C = ggml_rms_norm(ctx, C, norm_rms_eps);
+            dt = ggml_rms_norm(ctx, ggml_cont(ctx, dt), norm_rms_eps);
+            B = ggml_rms_norm(ctx, ggml_cont(ctx, B), norm_rms_eps);
+            C = ggml_rms_norm(ctx, ggml_cont(ctx, C), norm_rms_eps);
         }
 
         // {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs}
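Context for the llama.cpp change: dt, B and C are views into the fused SSM projection output, so they share its memory and are not contiguous, and backend norm kernels typically assert a contiguous input; ggml_cont copies a view into a freshly laid-out tensor before the norm. A minimal sketch of the pattern, assuming the public ggml C API; the sizes, eps value, and split offsets are illustrative stand-ins for what llm_build_mamba computes from the model hparams:

// illustrative only; not the llama.cpp graph-building code itself
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // stand-in for the fused SSM projection output: {dt_rank + 2*d_state, n_tokens}
    const int64_t dt_rank = 48, d_state = 16, n_tokens = 4;
    struct ggml_tensor * x_db = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, dt_rank + 2*d_state, n_tokens);
    ggml_set_f32(x_db, 1.0f);

    // dt, B and C are views into x_db: they share its memory, and their row
    // stride (nb[1]) is x_db's, so none of them is contiguous
    struct ggml_tensor * dt = ggml_view_2d(ctx, x_db, dt_rank, n_tokens, x_db->nb[1], 0);
    struct ggml_tensor * B  = ggml_view_2d(ctx, x_db, d_state, n_tokens, x_db->nb[1],
                                           ggml_element_size(x_db)*dt_rank);
    struct ggml_tensor * C  = ggml_view_2d(ctx, x_db, d_state, n_tokens, x_db->nb[1],
                                           ggml_element_size(x_db)*(dt_rank + d_state));

    // ggml_cont materializes each view as a contiguous copy, which is what
    // the backend norm kernels expect
    const float norm_rms_eps = 1e-6f;
    dt = ggml_rms_norm(ctx, ggml_cont(ctx, dt), norm_rms_eps);
    B  = ggml_rms_norm(ctx, ggml_cont(ctx, B),  norm_rms_eps);
    C  = ggml_rms_norm(ctx, ggml_cont(ctx, C),  norm_rms_eps);

    // build and run the graph on the CPU backend
    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, dt);
    ggml_build_forward_expand(gf, B);
    ggml_build_forward_expand(gf, C);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/1);

    ggml_free(ctx);
    return 0;
}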
