@@ -1195,8 +1195,8 @@ void ggml_vk_rope(
1195
1195
const std::shared_ptr<kp::Tensor>& inB,
1196
1196
const std::shared_ptr<kp::Tensor>& out,
1197
1197
uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
1198
- ggml_type src0t, int32_t n_dims, int32_t mode,
1199
- float freq_base, float freq_scale,
1198
+ ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_orig_ctx,
1199
+ float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
1200
1200
int32_t ne01, int32_t ne02, int32_t ne03,
1201
1201
uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
1202
1202
int32_t ne0,
@@ -1224,15 +1224,15 @@ void ggml_vk_rope(
1224
1224
1225
1225
struct PushConstants {
1226
1226
uint32_t inAOff, inBOff, outOff;
1227
- int32_t n_dims, mode;
1228
- float freq_base, freq_scale;
1227
+ int32_t n_dims, mode, n_orig_ctx ;
1228
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ;
1229
1229
uint32_t nb00, nb01, nb02, nb03;
1230
1230
int32_t ne0;
1231
1231
uint32_t nb0, nb1, nb2, nb3;
1232
1232
} pushConsts {
1233
1233
safe_divide (inAOff, type_size), safe_divide (inBOff, 4 ), safe_divide (outOff, type_size),
1234
- n_dims, mode,
1235
- freq_base, freq_scale,
1234
+ n_dims, mode, n_orig_ctx,
1235
+ freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
1236
1236
nb00, nb01, nb02, nb03,
1237
1237
ne0,
1238
1238
nb0, nb1, nb2, nb3
@@ -1545,13 +1545,23 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
1545
1545
GGML_ASSERT (ne10 == ne02);
1546
1546
GGML_ASSERT (src0t == dstt);
1547
1547
// const int n_past = ((int32_t *) dst->op_params)[0];
1548
- const int n_dims = ((int32_t *) dst->op_params )[1 ];
1549
- const int mode = ((int32_t *) dst->op_params )[2 ];
1550
- float freq_base;
1551
- float freq_scale;
1552
- memcpy (&freq_base, (int32_t *) dst->op_params + 4 , sizeof (float ));
1553
- memcpy (&freq_scale, (int32_t *) dst->op_params + 5 , sizeof (float ));
1554
- ggml_vk_rope (seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, freq_base, freq_scale, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3);
1548
+ const int n_dims = ((int32_t *) dst->op_params )[1 ];
1549
+ const int mode = ((int32_t *) dst->op_params )[2 ];
1550
+ // skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan
1551
+ const int n_orig_ctx = ((int32_t *) dst->op_params )[4 ];
1552
+
1553
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
1554
+ memcpy (&freq_base, (int32_t *) dst->op_params + 5 , sizeof (float ));
1555
+ memcpy (&freq_scale, (int32_t *) dst->op_params + 6 , sizeof (float ));
1556
+ memcpy (&ext_factor, (int32_t *) dst->op_params + 7 , sizeof (float ));
1557
+ memcpy (&attn_factor, (int32_t *) dst->op_params + 8 , sizeof (float ));
1558
+ memcpy (&beta_fast, (int32_t *) dst->op_params + 9 , sizeof (float ));
1559
+ memcpy (&beta_slow, (int32_t *) dst->op_params + 10 , sizeof (float ));
1560
+ ggml_vk_rope (
1561
+ seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_orig_ctx,
1562
+ freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
1563
+ ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3
1564
+ );
1555
1565
} break ;
1556
1566
case GGML_OP_DUP:
1557
1567
case GGML_OP_CPY:
0 commit comments