@@ -104,6 +104,7 @@ static std::string format(const char * fmt, ...) {
 #define TN_POS_EMBD   "%s.position_embd.weight"
 #define TN_CLASS_EMBD "v.class_embd"
 #define TN_PATCH_EMBD "v.patch_embd.weight"
+#define TN_PATCH_BIAS "v.patch_embd.bias"
 #define TN_ATTN_K     "%s.blk.%d.attn_k.%s"
 #define TN_ATTN_Q     "%s.blk.%d.attn_q.%s"
 #define TN_ATTN_V     "%s.blk.%d.attn_v.%s"
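For reference (illustration, not part of the change): the templates containing %s/%d are expanded through the file's format() helper, while the new TN_PATCH_BIAS is used verbatim. For example:

// Assumed example arguments, showing how the name templates resolve:
// format(TN_ATTN_K, "v", 0, "weight") -> "v.blk.0.attn_k.weight"
// TN_PATCH_BIAS (no placeholders)     -> "v.patch_embd.bias"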
@@ -425,6 +426,7 @@ struct clip_vision_model {
     // embeddings
     struct ggml_tensor * class_embedding;
     struct ggml_tensor * patch_embeddings;
+    struct ggml_tensor * patch_bias;
     struct ggml_tensor * position_embeddings;
 
     struct ggml_tensor * pre_ln_w;
@@ -501,6 +503,11 @@ struct clip_ctx {
     bool use_gelu = false;
     int32_t ftype = 1;
 
+    bool has_class_embedding = true;
+    bool has_pre_norm  = true;
+    bool has_post_norm = false;
+    bool has_patch_bias = false;
+
     struct gguf_context * ctx_gguf;
     struct ggml_context * ctx_data;
 
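The defaults above describe the original CLIP-style tower (class token and pre-layernorm, no post-layernorm, no patch bias). A sketch of how an encoder variant without a class token might be described with the same flags (values are assumptions for illustration; nothing in this diff sets them this way):

// Hypothetical configuration for a class-token-free encoder
// that applies a final layernorm and a patch-embedding bias:
ctx->has_class_embedding = false;
ctx->has_pre_norm        = false;
ctx->has_post_norm       = true;
ctx->has_patch_bias      = true;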
@@ -526,7 +533,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
     const int patch_size = hparams.patch_size;
     const int num_patches = ((image_size / patch_size) * (image_size / patch_size));
     const int num_patches_per_side = image_size / patch_size; GGML_UNUSED(num_patches_per_side);
-    const int num_positions = num_patches + 1;
+    const int num_positions = num_patches + (ctx->has_class_embedding ? 1 : 0);
     const int hidden_size = hparams.hidden_size;
     const int n_head = hparams.n_head;
     const int d_head = hidden_size / n_head;
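A quick worked example with assumed hyperparameters (e.g. a ViT-L/14 tower at 336 px, not values taken from this diff):

// image_size = 336, patch_size = 14
// num_patches   = (336 / 14) * (336 / 14) = 24 * 24 = 576
// num_positions = 576 + 1 = 577 with a class embedding, 576 without one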
@@ -557,16 +564,23 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
     inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size);
     inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3));
 
+    if (ctx->has_patch_bias) {
+        // inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp));
+        inp = ggml_add(ctx0, inp, model.patch_bias);
+    }
+
     // concat class_embeddings and patch_embeddings
-    struct ggml_tensor * embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
+    struct ggml_tensor * embeddings = inp;
+    if (ctx->has_class_embedding) {
+        embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
+        embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
+                embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
+        embeddings = ggml_acc(ctx0, embeddings, inp,
+                embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
+    }
     ggml_set_name(embeddings, "embeddings");
     ggml_set_input(embeddings);
 
-    embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
-            embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
-
-    embeddings = ggml_acc(ctx0, embeddings, inp,
-            embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
 
     struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions);
     ggml_set_name(positions, "positions");
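A note on the two ggml_acc calls (my reading of the layout, not text from the PR): embeddings holds one row of hidden_size values per position, so the byte offsets prepend the class token to the patch embeddings.

// First acc writes class_embedding at byte offset 0             -> row 0
// Second acc writes inp at byte offset class_embedding->nb[1]
// (the size of one row)                                         -> rows 1..num_patches
// Net effect: embeddings = [class_embedding ; patch embeddings]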
@@ -576,7 +590,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
         ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions));
 
     // pre-layernorm
-    {
+    if (ctx->has_pre_norm) {
         embeddings = ggml_norm(ctx0, embeddings, eps);
         ggml_set_name(embeddings, "pre_ln");
 
@@ -664,6 +678,14 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
         embeddings = cur;
     }
 
+    // post-layernorm
+    if (ctx->has_post_norm) {
+        embeddings = ggml_norm(ctx0, embeddings, eps);
+        ggml_set_name(embeddings, "post_ln");
+
+        embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
+    }
+
     // llava projector
     {
         embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
@@ -1148,12 +1170,39 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
         }
 
+        try {
+            vision_model.class_embedding = get_tensor(new_clip->ctx_data, TN_CLASS_EMBD);
+            new_clip->has_class_embedding = true;
+        } catch (const std::exception & e) {
+            new_clip->has_class_embedding = false;
+        }
+
+        try {
+            vision_model.pre_ln_w = get_tensor(new_clip->ctx_data, format(TN_LN_PRE, "v", "weight"));
+            vision_model.pre_ln_b = get_tensor(new_clip->ctx_data, format(TN_LN_PRE, "v", "bias"));
+            new_clip->has_pre_norm = true;
+        } catch (std::exception & e) {
+            new_clip->has_pre_norm = false;
+        }
+
+        try {
+            vision_model.post_ln_w = get_tensor(new_clip->ctx_data, format(TN_LN_POST, "v", "weight"));
+            vision_model.post_ln_b = get_tensor(new_clip->ctx_data, format(TN_LN_POST, "v", "bias"));
+            new_clip->has_post_norm = true;
+        } catch (std::exception & e) {
+            new_clip->has_post_norm = false;
+        }
+
+        try {
+            vision_model.patch_bias = get_tensor(new_clip->ctx_data, TN_PATCH_BIAS);
+            new_clip->has_patch_bias = true;
+        } catch (std::exception & e) {
+            new_clip->has_patch_bias = false;
+        }
+
         try {
             vision_model.patch_embeddings = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD);
-            vision_model.class_embedding = get_tensor(new_clip->ctx_data, TN_CLASS_EMBD);
             vision_model.position_embeddings = get_tensor(new_clip->ctx_data, format(TN_POS_EMBD, "v"));
-            vision_model.pre_ln_w = get_tensor(new_clip->ctx_data, format(TN_LN_PRE, "v", "weight"));
-            vision_model.pre_ln_b = get_tensor(new_clip->ctx_data, format(TN_LN_PRE, "v", "bias"));
         } catch (const std::exception & e) {
             LOG_TEE("%s: failed to load vision model tensors\n", __func__);
         }
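The repeated try/catch blocks above probe for optional tensors and record whether each one exists. A minimal sketch of the same pattern factored into a helper, assuming clip.cpp's get_tensor(ggml_context *, std::string) throws when a tensor is missing (the helper itself is hypothetical, not part of this change):

static struct ggml_tensor * get_tensor_optional(struct ggml_context * ctx, const std::string & name, bool & found) {
    try {
        struct ggml_tensor * t = get_tensor(ctx, name); // throws if `name` is absent
        found = true;
        return t;
    } catch (const std::exception &) {
        found = false;
        return nullptr;
    }
}

// possible usage:
// vision_model.patch_bias = get_tensor_optional(new_clip->ctx_data, TN_PATCH_BIAS, new_clip->has_patch_bias);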