Skip to content

Commit 2b61001

Browse files
committed
CUDA: fix illegal memory access
1 parent da5f409 commit 2b61001

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

ggml-cuda.cu

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3154,23 +3154,27 @@ static __device__ float rope_ntkv2_ramp(const float low, const float high, const
31543154
return 1.0f - min(1.0f, max(0.0f, y));
31553155
}
31563156

3157+
struct rope_corr_factors {
3158+
float v[4];
3159+
};
3160+
31573161
// NTKv2 algorithm based on LlamaPartNTKScaledRotaryEmbedding.py from https://github.com/jquesnelle/scaled-rope
31583162
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
31593163
static __device__ float rope_ntkv2(
31603164
const float theta_base,
31613165
const float theta_linear,
31623166
const float theta_ntk,
3163-
const float corr_factors[4],
3167+
const rope_corr_factors corr_factors,
31643168
const int64_t i0,
31653169
const float ntk_factor,
31663170
const float ext_factor) {
31673171
float ramp_mix;
31683172
float theta;
31693173

3170-
ramp_mix = rope_ntkv2_ramp(corr_factors[0], corr_factors[1], i0) * ntk_factor;
3174+
ramp_mix = rope_ntkv2_ramp(corr_factors.v[0], corr_factors.v[1], i0) * ntk_factor;
31713175
theta = theta_linear * (1 - ramp_mix) + theta_ntk * ramp_mix;
31723176

3173-
ramp_mix = rope_ntkv2_ramp(corr_factors[2], corr_factors[3], i0) * ext_factor;
3177+
ramp_mix = rope_ntkv2_ramp(corr_factors.v[2], corr_factors.v[3], i0) * ext_factor;
31743178
theta = theta * (1 - ramp_mix) + theta_base * ramp_mix;
31753179
return theta;
31763180
}
@@ -3187,7 +3191,7 @@ static __global__ void rope_f32(
31873191
const float theta_ntk_scale,
31883192
const float p0,
31893193
const int p_delta_rows,
3190-
const float corr_factors[4]) {
3194+
const rope_corr_factors corr_factors) {
31913195
const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x);
31923196

31933197
if (col >= ncols) {
@@ -3817,7 +3821,7 @@ static void rope_f32_cuda(
38173821
const float theta_ntk_scale,
38183822
const float p0,
38193823
const int p_delta_rows,
3820-
const float corr_factors[4],
3824+
const rope_corr_factors corr_factors,
38213825
cudaStream_t stream) {
38223826
GGML_ASSERT(nrows % 2 == 0);
38233827
const dim3 block_dims(2*CUDA_ROPE_BLOCK_SIZE, 1, 1);
@@ -4546,8 +4550,8 @@ inline void ggml_cuda_op_rope(
45464550
} else {
45474551
const float p0 = (mode & 1) == 0 ? n_past : 0;
45484552
const float theta_ntk_scale = powf(freq_base * powf(freq_scale, (n_dims / (n_dims - 2.0f))), -2.0f/n_dims);
4549-
float corr_factors[4];
4550-
ggml_rope_ntkv2_corr_factors(n_dims, freq_base, corr_factors);
4553+
rope_corr_factors corr_factors;
4554+
ggml_rope_ntkv2_corr_factors(n_dims, freq_base, corr_factors.v);
45514555

45524556
rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, freq_scale, ntk_factor, ext_factor, theta_scale,
45534557
theta_ntk_scale, p0, ne01, corr_factors, cudaStream_main);

0 commit comments

Comments
 (0)