@@ -58,10 +58,10 @@ static __global__ void rope(
     dst[i + 1] = x0*sin_theta + x1*cos_theta;
 }

-template<typename T, bool has_pos>
+template<typename T, bool has_pos, bool has_freq_facs>
 static __global__ void rope_neox(
     const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims, const float * freq_factors
 ) {
     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);

@@ -88,7 +88,9 @@ static __global__ void rope_neox(
     float cur_rot = inv_ndims*ic - ib;

     const int p = has_pos ? pos[i2] : 0;
-    const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f);
+    const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;
+
+    const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f)/freq_factor;

     float cos_theta, sin_theta;
     rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
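The only functional change in the kernel is the division of theta_base by a per-pair frequency factor when one is provided. The standalone sketch below (not part of the patch; the pair index i0 corresponds to ic/2 in the kernel, and the numeric values are arbitrary) shows how a factor greater than 1 slows the rotation of that pair, i.e. stretches its wavelength, which is how per-frequency RoPE scaling extends the usable context:

```cpp
#include <cmath>
#include <cstdio>

// Effective RoPE angle for one rotated pair, mirroring the kernel expression
// p*freq_scale*powf(theta_scale, col/2.0f)/freq_factor (i0 == col/2 == ic/2 here).
static float rope_neox_theta(int p, int i0, float freq_scale, float theta_scale,
                             const float * freq_factors /* may be nullptr */) {
    const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;
    return p*freq_scale*powf(theta_scale, (float)i0)/freq_factor;
}

int main() {
    const float factors[4] = {1.0f, 1.0f, 4.0f, 8.0f}; // hypothetical per-pair factors
    for (int i0 = 0; i0 < 4; ++i0) {
        printf("pair %d: theta = %.4f without factors, %.4f with factors\n", i0,
               rope_neox_theta(100, i0, 1.0f, 0.82f, nullptr),
               rope_neox_theta(100, i0, 1.0f, 0.82f, factors));
    }
    return 0;
}
```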
@@ -164,7 +166,7 @@ static void rope_cuda(
 template<typename T>
 static void rope_neox_cuda(
     const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
 ) {
     GGML_ASSERT(ncols % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
@@ -175,15 +177,29 @@ static void rope_neox_cuda(
     const float inv_ndims = -1.0f / n_dims;

     if (pos == nullptr) {
-        rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-            theta_scale, inv_ndims
-        );
+        if (freq_factors == nullptr) {
+            rope_neox<T, false, false><<<block_nums, block_dims, 0, stream>>>(
+                x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+                theta_scale, inv_ndims, freq_factors
+            );
+        } else {
+            rope_neox<T, false, true><<<block_nums, block_dims, 0, stream>>>(
+                x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+                theta_scale, inv_ndims, freq_factors
+            );
+        }
     } else {
-        rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-            theta_scale, inv_ndims
-        );
+        if (freq_factors == nullptr) {
+            rope_neox<T, true, false><<<block_nums, block_dims, 0, stream>>>(
+                x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+                theta_scale, inv_ndims, freq_factors
+            );
+        } else {
+            rope_neox<T, true, true><<<block_nums, block_dims, 0, stream>>>(
+                x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+                theta_scale, inv_ndims, freq_factors
+            );
+        }
     }
 }

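The launcher now enumerates all four <has_pos, has_freq_facs> combinations so both flags stay compile-time constants inside the kernel and the per-element branches are eliminated. A minimal host-only C++ sketch of the same dispatch pattern (hypothetical names; no CUDA needed to compile it):

```cpp
#include <cstdio>

// Both flags are template parameters, so the ternaries below are resolved at
// compile time in each instantiation, just like has_pos/has_freq_facs in rope_neox.
template<bool has_pos, bool has_freq_facs>
static void kernel_body(int p, const float * freq_factors) {
    const int   pos    = has_pos       ? p               : 0;
    const float factor = has_freq_facs ? freq_factors[0] : 1.0f;
    printf("pos=%d factor=%.2f\n", pos, factor);
}

// One runtime decision per launch picks the right instantiation.
static void launch(const int * pos, const float * freq_factors) {
    if (pos == nullptr) {
        if (freq_factors == nullptr) { kernel_body<false, false>(0, freq_factors); }
        else                         { kernel_body<false, true >(0, freq_factors); }
    } else {
        if (freq_factors == nullptr) { kernel_body<true, false>(*pos, freq_factors); }
        else                         { kernel_body<true, true >(*pos, freq_factors); }
    }
}

int main() {
    const int   p = 42;
    const float f = 0.25f;
    launch(nullptr, nullptr); // pos=0  factor=1.00
    launch(&p, &f);           // pos=42 factor=0.25
    return 0;
}
```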
@@ -214,24 +230,27 @@ static void rope_cuda_f32(

 static void rope_neox_cuda_f16(
     const half * x, half * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) {
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {

-    rope_neox_cuda<half>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
+    rope_neox_cuda<half>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 }

 static void rope_neox_cuda_f32(
     const float * x, float * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
 ) {

-    rope_neox_cuda<float>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
+    rope_neox_cuda<float>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 }

 void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
     const float * src0_d = (const float *)src0->data;
     const float * src1_d = (const float *)src1->data;
+
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();

@@ -241,7 +260,6 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
-    const int64_t ne2 = dst->ne[2];
     const int64_t nrows = ggml_nrows(src0);

     //const int n_past = ((int32_t *) dst->op_params)[0];
@@ -259,16 +277,22 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     memcpy(&beta_fast, (int32_t *) dst->op_params +  9, sizeof(float));
     memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));

+    const float * freq_factors = nullptr;
     const int32_t * pos = nullptr;
-    if ((mode & 1) == 0) {
-        GGML_ASSERT(src1->type == GGML_TYPE_I32);
-        GGML_ASSERT(src1->ne[0] == ne2);
-        pos = (const int32_t *) src1_d;
-    }

     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;

+    if (is_neox) {
+        pos = (const int32_t *) src1_d;
+
+        if (src2 != nullptr) {
+            freq_factors = (const float *) src2->data;
+        }
+    } else {
+        GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox");
+    }
+
     rope_corr_dims corr_dims;
     ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);

@@ -280,12 +304,12 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     if (src0->type == GGML_TYPE_F32) {
         rope_neox_cuda_f32(
             (const float *)src0_d, (float *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
-            attn_factor, corr_dims, stream
+            attn_factor, corr_dims, freq_factors, stream
         );
     } else if (src0->type == GGML_TYPE_F16) {
         rope_neox_cuda_f16(
             (const half *)src0_d, (half *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
-            attn_factor, corr_dims, stream
+            attn_factor, corr_dims, freq_factors, stream
         );
     } else {
         GGML_ASSERT(false);
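Putting the pieces together, the following simplified CPU reference (an illustration only: it skips rope_yarn()'s ext_factor/attn_factor corrections and assumes ncols == n_dims, with arbitrary example values) performs the NeoX-layout rotation with an optional freq_factors buffer of n_dims/2 per-pair factors, matching the indexing used by the kernel:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Simplified reference: rotate one row of n_dims values in the NeoX layout,
// ignoring rope_yarn()'s corrections and assuming ncols == n_dims.
static void rope_neox_ref(const float * x, float * dst, int n_dims, int pos,
                          float freq_base, float freq_scale,
                          const float * freq_factors /* may be nullptr */) {
    const float theta_scale = powf(freq_base, -2.0f/n_dims);
    for (int i0 = 0; i0 < n_dims/2; ++i0) {
        const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;
        const float theta       = pos*freq_scale*powf(theta_scale, (float)i0)/freq_factor;
        const float cos_theta   = cosf(theta);
        const float sin_theta   = sinf(theta);
        // NeoX layout: the rotated pair is (i0, i0 + n_dims/2).
        const float x0 = x[i0];
        const float x1 = x[i0 + n_dims/2];
        dst[i0]            = x0*cos_theta - x1*sin_theta;
        dst[i0 + n_dims/2] = x0*sin_theta + x1*cos_theta;
    }
}

int main() {
    const int n_dims = 8;
    std::vector<float> x(n_dims, 1.0f), dst(n_dims);
    const float factors[4] = {1.0f, 1.0f, 4.0f, 8.0f}; // hypothetical per-pair factors
    rope_neox_ref(x.data(), dst.data(), n_dims, /*pos=*/100,
                  /*freq_base=*/10000.0f, /*freq_scale=*/1.0f, factors);
    for (int i = 0; i < n_dims; ++i) {
        printf("dst[%d] = %f\n", i, dst[i]);
    }
    return 0;
}
```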