Commit 8088b6d

Refactor: wtype per tensor
1 parent ac54e00

21 files changed: +203 -163 lines
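
The change replaces the single model-wide wtype that GGMLRunner and every init_params used to receive with a std::map<std::string, enum ggml_type> keyed by fully qualified tensor name, plus a per-block name prefix, so each tensor can keep the type it was stored with (a quantized token embedding next to an F32 position embedding, for example). A minimal sketch of the lookup pattern the hunks below repeat; the helper name type_or_default is illustrative, not part of the commit:

    #include <map>
    #include <string>
    #include "ggml.h"  // for enum ggml_type (assumed include path)

    // Look up a tensor's stored type by fully qualified name, falling back
    // to a default when the loader recorded nothing for that tensor.
    static enum ggml_type type_or_default(const std::map<std::string, enum ggml_type>& tensor_types,
                                          const std::string& name,
                                          enum ggml_type fallback) {
        auto it = tensor_types.find(name);
        return (it != tensor_types.end()) ? it->second : fallback;
    }

    // Inside an init_params(ctx, tensor_types, prefix), mirroring CLIPEmbeddings below:
    //   enum ggml_type token_wtype = type_or_default(tensor_types, prefix + "token_embedding.weight", GGML_TYPE_F32);
    //   params["token_embedding.weight"] = ggml_new_tensor_2d(ctx, token_wtype, embed_dim, vocab_size);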

clip.hpp

Lines changed: 27 additions & 19 deletions
@@ -533,9 +533,12 @@ class CLIPEmbeddings : public GGMLBlock {
     int64_t vocab_size;
     int64_t num_positions;
 
-    void init_params(struct ggml_context* ctx, ggml_type wtype) {
-        params["token_embedding.weight"]    = ggml_new_tensor_2d(ctx, wtype, embed_dim, vocab_size);
-        params["position_embedding.weight"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, embed_dim, num_positions);
+    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
+        enum ggml_type token_wtype    = (tensor_types.find(prefix + "token_embedding.weight") != tensor_types.end()) ? tensor_types[prefix + "token_embedding.weight"] : GGML_TYPE_F32;
+        enum ggml_type position_wtype = GGML_TYPE_F32;  // (tensor_types.find(prefix + "position_embedding.weight") != tensor_types.end()) ? tensor_types[prefix + "position_embedding.weight"] : GGML_TYPE_F32;
+
+        params["token_embedding.weight"]    = ggml_new_tensor_2d(ctx, token_wtype, embed_dim, vocab_size);
+        params["position_embedding.weight"] = ggml_new_tensor_2d(ctx, position_wtype, embed_dim, num_positions);
     }
 
 public:
@@ -579,11 +582,14 @@ class CLIPVisionEmbeddings : public GGMLBlock {
     int64_t image_size;
     int64_t num_patches;
     int64_t num_positions;
+    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
+        enum ggml_type patch_wtype    = GGML_TYPE_F16;  // tensor_types.find(prefix + "patch_embedding.weight") != tensor_types.end() ? tensor_types[prefix + "patch_embedding.weight"] : GGML_TYPE_F16;
+        enum ggml_type class_wtype    = GGML_TYPE_F32;  // tensor_types.find(prefix + "class_embedding") != tensor_types.end() ? tensor_types[prefix + "class_embedding"] : GGML_TYPE_F32;
+        enum ggml_type position_wtype = GGML_TYPE_F32;  // tensor_types.find(prefix + "position_embedding.weight") != tensor_types.end() ? tensor_types[prefix + "position_embedding.weight"] : GGML_TYPE_F32;
 
-    void init_params(struct ggml_context* ctx, ggml_type wtype) {
-        params["patch_embedding.weight"]    = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, patch_size, patch_size, num_channels, embed_dim);
-        params["class_embedding"]           = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, embed_dim);
-        params["position_embedding.weight"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, embed_dim, num_positions);
+        params["patch_embedding.weight"]    = ggml_new_tensor_4d(ctx, patch_wtype, patch_size, patch_size, num_channels, embed_dim);
+        params["class_embedding"]           = ggml_new_tensor_1d(ctx, class_wtype, embed_dim);
+        params["position_embedding.weight"] = ggml_new_tensor_2d(ctx, position_wtype, embed_dim, num_positions);
     }
 
 public:
@@ -639,9 +645,10 @@ enum CLIPVersion {
 
 class CLIPTextModel : public GGMLBlock {
 protected:
-    void init_params(struct ggml_context* ctx, ggml_type wtype) {
+    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
         if (version == OPEN_CLIP_VIT_BIGG_14) {
-            params["text_projection"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, projection_dim, hidden_size);
+            enum ggml_type wtype = GGML_TYPE_F32;  // tensor_types.find(prefix + "text_projection") != tensor_types.end() ? tensor_types[prefix + "text_projection"] : GGML_TYPE_F32;
+            params["text_projection"] = ggml_new_tensor_2d(ctx, wtype, projection_dim, hidden_size);
         }
     }
 
@@ -779,9 +786,9 @@ class CLIPProjection : public UnaryBlock {
     int64_t out_features;
     bool transpose_weight;
 
-    void init_params(struct ggml_context* ctx, ggml_type wtype) {
+    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
+        enum ggml_type wtype = tensor_types.find(prefix + "weight") != tensor_types.end() ? tensor_types[prefix + "weight"] : GGML_TYPE_F32;
         if (transpose_weight) {
-            LOG_ERROR("transpose_weight");
             params["weight"] = ggml_new_tensor_2d(ctx, wtype, out_features, in_features);
         } else {
             params["weight"] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features);
@@ -842,12 +849,13 @@ struct CLIPTextModelRunner : public GGMLRunner {
     CLIPTextModel model;
 
     CLIPTextModelRunner(ggml_backend_t backend,
-                        ggml_type wtype,
+                        std::map<std::string, enum ggml_type>& tensor_types,
+                        const std::string prefix,
                         CLIPVersion version = OPENAI_CLIP_VIT_L_14,
                         int clip_skip_value = 1,
                         bool with_final_ln  = true)
-        : GGMLRunner(backend, wtype), model(version, clip_skip_value, with_final_ln) {
-        model.init(params_ctx, wtype);
+        : GGMLRunner(backend), model(version, clip_skip_value, with_final_ln) {
+        model.init(params_ctx, tensor_types, prefix);
     }
 
     std::string get_desc() {
@@ -889,13 +897,13 @@ struct CLIPTextModelRunner : public GGMLRunner {
         struct ggml_tensor* embeddings = NULL;
 
         if (num_custom_embeddings > 0 && custom_embeddings_data != NULL) {
-            auto custom_embeddings = ggml_new_tensor_2d(compute_ctx,
-                                                        wtype,
-                                                        model.hidden_size,
-                                                        num_custom_embeddings);
+            auto token_embed_weight = model.get_token_embed_weight();
+            auto custom_embeddings  = ggml_new_tensor_2d(compute_ctx,
+                                                         token_embed_weight->type,
+                                                         model.hidden_size,
+                                                         num_custom_embeddings);
             set_backend_tensor_data(custom_embeddings, custom_embeddings_data);
 
-            auto token_embed_weight = model.get_token_embed_weight();
             // concatenate custom embeddings
             embeddings = ggml_concat(compute_ctx, token_embed_weight, custom_embeddings, 1);
         }
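
The last hunk also reorders the custom-embedding setup so the new tensor's type is taken from the token embedding weight itself: ggml_concat expects its operands to share a type, and with per-tensor types there is no single runner-wide wtype left to assume. Condensed view of the new flow (same names as the hunk above):

    auto token_embed_weight = model.get_token_embed_weight();
    // Create the custom-embedding tensor with the weight's own stored type,
    // so the concat below is type-consistent whatever the checkpoint used.
    auto custom_embeddings = ggml_new_tensor_2d(compute_ctx,
                                                token_embed_weight->type,
                                                model.hidden_size,
                                                num_custom_embeddings);
    set_backend_tensor_data(custom_embeddings, custom_embeddings_data);
    embeddings = ggml_concat(compute_ctx, token_embed_weight, custom_embeddings, 1);  // along dim 1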

common.hpp

Lines changed: 8 additions & 4 deletions
@@ -182,9 +182,11 @@ class GEGLU : public GGMLBlock {
     int64_t dim_in;
     int64_t dim_out;
 
-    void init_params(struct ggml_context* ctx, ggml_type wtype) {
+    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, std::string prefix = "") {
+        enum ggml_type wtype      = (tensor_types.find(prefix + "proj.weight") != tensor_types.end()) ? tensor_types[prefix + "proj.weight"] : GGML_TYPE_F32;
+        enum ggml_type bias_wtype = GGML_TYPE_F32;  // (tensor_types.find(prefix + "proj.bias") != tensor_types.end()) ? tensor_types[prefix + "proj.bias"] : GGML_TYPE_F32;
         params["proj.weight"] = ggml_new_tensor_2d(ctx, wtype, dim_in, dim_out * 2);
-        params["proj.bias"]   = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dim_out * 2);
+        params["proj.bias"]   = ggml_new_tensor_1d(ctx, bias_wtype, dim_out * 2);
     }
 
 public:
@@ -433,8 +435,10 @@ class SpatialTransformer : public GGMLBlock {
 
 class AlphaBlender : public GGMLBlock {
 protected:
-    void init_params(struct ggml_context* ctx, ggml_type wtype) {
-        params["mix_factor"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
+    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, std::string prefix = "") {
+        // Get the type of the "mix_factor" tensor from the input tensors map with the specified prefix
+        enum ggml_type wtype = GGML_TYPE_F32;  // (tensor_types.find(prefix + "mix_factor") != tensor_types.end()) ? tensor_types[prefix + "mix_factor"] : GGML_TYPE_F32;
+        params["mix_factor"] = ggml_new_tensor_1d(ctx, wtype, 1);
     }
 
     float get_alpha() {

conditioner.hpp

Lines changed: 20 additions & 24 deletions
@@ -45,7 +45,6 @@ struct Conditioner {
 struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
     SDVersion version = VERSION_SD1;
     CLIPTokenizer tokenizer;
-    ggml_type wtype;
     std::shared_ptr<CLIPTextModelRunner> text_model;
     std::shared_ptr<CLIPTextModelRunner> text_model2;
 
@@ -56,24 +55,24 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
     std::vector<std::string> readed_embeddings;
 
     FrozenCLIPEmbedderWithCustomWords(ggml_backend_t backend,
-                                      ggml_type wtype,
+                                      std::map<std::string, enum ggml_type>& tensor_types,
                                       const std::string& embd_dir,
                                       SDVersion version = VERSION_SD1,
                                       int clip_skip     = -1)
-        : version(version), tokenizer(version == VERSION_SD2 ? 0 : 49407), embd_dir(embd_dir), wtype(wtype) {
+        : version(version), tokenizer(version == VERSION_SD2 ? 0 : 49407), embd_dir(embd_dir) {
         if (clip_skip <= 0) {
            clip_skip = 1;
            if (version == VERSION_SD2 || version == VERSION_SDXL) {
                clip_skip = 2;
            }
        }
        if (version == VERSION_SD1) {
-            text_model = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip);
+            text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip);
        } else if (version == VERSION_SD2) {
-            text_model = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPEN_CLIP_VIT_H_14, clip_skip);
+            text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPEN_CLIP_VIT_H_14, clip_skip);
        } else if (version == VERSION_SDXL) {
-            text_model  = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip, false);
-            text_model2 = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
+            text_model  = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip, false);
+            text_model2 = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.1.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
        }
     }
 
@@ -136,14 +135,14 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
             LOG_DEBUG("embedding wrong hidden size, got %i, expected %i", tensor_storage.ne[0], hidden_size);
             return false;
         }
-        embd = ggml_new_tensor_2d(embd_ctx, wtype, hidden_size, tensor_storage.n_dims > 1 ? tensor_storage.ne[1] : 1);
+        embd = ggml_new_tensor_2d(embd_ctx, tensor_storage.type, hidden_size, tensor_storage.n_dims > 1 ? tensor_storage.ne[1] : 1);
         *dst_tensor = embd;
         return true;
     };
     model_loader.load_tensors(on_load, NULL);
     readed_embeddings.push_back(embd_name);
     token_embed_custom.resize(token_embed_custom.size() + ggml_nbytes(embd));
-    memcpy((void*)(token_embed_custom.data() + num_custom_embeddings * hidden_size * ggml_type_size(wtype)),
+    memcpy((void*)(token_embed_custom.data() + num_custom_embeddings * hidden_size * ggml_type_size(embd->type)),
            embd->data,
            ggml_nbytes(embd));
     for (int i = 0; i < embd->ne[1]; i++) {
@@ -585,9 +584,9 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
 struct FrozenCLIPVisionEmbedder : public GGMLRunner {
     CLIPVisionModelProjection vision_model;
 
-    FrozenCLIPVisionEmbedder(ggml_backend_t backend, ggml_type wtype)
-        : vision_model(OPEN_CLIP_VIT_H_14, true), GGMLRunner(backend, wtype) {
-        vision_model.init(params_ctx, wtype);
+    FrozenCLIPVisionEmbedder(ggml_backend_t backend, std::map<std::string, enum ggml_type>& tensor_types)
+        : vision_model(OPEN_CLIP_VIT_H_14, true), GGMLRunner(backend) {
+        vision_model.init(params_ctx, tensor_types, "cond_stage_model.transformer");
     }
 
     std::string get_desc() {
@@ -622,7 +621,6 @@ struct FrozenCLIPVisionEmbedder : public GGMLRunner {
 };
 
 struct SD3CLIPEmbedder : public Conditioner {
-    ggml_type wtype;
     CLIPTokenizer clip_l_tokenizer;
     CLIPTokenizer clip_g_tokenizer;
     T5UniGramTokenizer t5_tokenizer;
@@ -631,15 +629,15 @@ struct SD3CLIPEmbedder : public Conditioner {
     std::shared_ptr<T5Runner> t5;
 
     SD3CLIPEmbedder(ggml_backend_t backend,
-                    ggml_type wtype,
+                    std::map<std::string, enum ggml_type>& tensor_types,
                     int clip_skip = -1)
-        : wtype(wtype), clip_g_tokenizer(0) {
+        : clip_g_tokenizer(0) {
         if (clip_skip <= 0) {
             clip_skip = 2;
         }
-        clip_l = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip, false);
-        clip_g = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
-        t5     = std::make_shared<T5Runner>(backend, wtype);
+        clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip, false);
+        clip_g = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_g.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
+        t5     = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
     }
 
     void set_clip_skip(int clip_skip) {
@@ -979,21 +977,19 @@ struct SD3CLIPEmbedder : public Conditioner {
 };
 
 struct FluxCLIPEmbedder : public Conditioner {
-    ggml_type wtype;
     CLIPTokenizer clip_l_tokenizer;
     T5UniGramTokenizer t5_tokenizer;
     std::shared_ptr<CLIPTextModelRunner> clip_l;
     std::shared_ptr<T5Runner> t5;
 
     FluxCLIPEmbedder(ggml_backend_t backend,
-                     ggml_type wtype,
-                     int clip_skip = -1)
-        : wtype(wtype) {
+                     std::map<std::string, enum ggml_type>& tensor_types,
+                     int clip_skip = -1) {
         if (clip_skip <= 0) {
             clip_skip = 2;
         }
-        clip_l = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip, true);
-        t5     = std::make_shared<T5Runner>(backend, wtype);
+        clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip, true);
+        t5     = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
     }
 
     void set_clip_skip(int clip_skip) {
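
The hard-coded prefixes above ("cond_stage_model.transformer.text_model", "text_encoders.clip_l.transformer.text_model", ...) select each sub-model's slice of one shared map. A hypothetical sketch of what callers now hand the embedders; the keys and types shown are illustrative only, not taken from the commit:

    // Fully qualified tensor names as found in the checkpoint, mapped to the
    // types they are stored in (normally filled by the model loader).
    std::map<std::string, enum ggml_type> tensor_types = {
        {"text_encoders.clip_l.transformer.text_model.embeddings.token_embedding.weight", GGML_TYPE_Q8_0},
        {"text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.q.weight", GGML_TYPE_Q4_K},
    };

    // Each embedder reads only entries under its prefix and falls back to the
    // F32/F16 defaults for anything missing from the map.
    FluxCLIPEmbedder cond(backend, tensor_types);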

control.hpp

Lines changed: 5 additions & 3 deletions
@@ -317,10 +317,12 @@ struct ControlNet : public GGMLRunner {
     bool guided_hint_cached = false;
 
     ControlNet(ggml_backend_t backend,
-               ggml_type wtype,
                SDVersion version = VERSION_SD1)
-        : GGMLRunner(backend, wtype), control_net(version) {
-        control_net.init(params_ctx, wtype);
+        : GGMLRunner(backend), control_net(version) {
+    }
+
+    void init_params(std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix) {
+        control_net.init(params_ctx, tensor_types, prefix);
     }
 
     ~ControlNet() {
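
Unlike the embedders, ControlNet (and ESRGAN below) can no longer create its parameters in the constructor: the tensor-type map and prefix are only known once the model file has been inspected, so construction and parameter initialization become two steps. A sketch of the resulting call sequence; the "control_model" prefix and the loader step are assumptions, not shown in this commit:

    // 1. Construct the runner; no parameter tensors exist yet.
    ControlNet control_net(backend, VERSION_SD1);

    // 2. After scanning the checkpoint into a name -> type map (loader step
    //    not part of this commit), create the params with per-tensor types.
    std::map<std::string, enum ggml_type> tensor_types;  // populated by the model loader
    control_net.init_params(tensor_types, "control_model");  // prefix assumed for illustration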

diffusion_model.hpp

Lines changed: 7 additions & 6 deletions
@@ -30,9 +30,10 @@ struct UNetModel : public DiffusionModel {
     UNetModelRunner unet;
 
     UNetModel(ggml_backend_t backend,
-              ggml_type wtype,
+              std::map<std::string, enum ggml_type>& tensor_types,
              SDVersion version = VERSION_SD1)
-        : unet(backend, wtype, version) {
+        : unet(backend, version) {
+        unet.init_params(tensor_types, "model.diffusion_model");
     }
 
     void alloc_params_buffer() {
@@ -79,9 +80,9 @@ struct MMDiTModel : public DiffusionModel {
     MMDiTRunner mmdit;
 
     MMDiTModel(ggml_backend_t backend,
-               ggml_type wtype,
+               std::map<std::string, enum ggml_type>& tensor_types,
               SDVersion version = VERSION_SD3_2B)
-        : mmdit(backend, wtype, version) {
+        : mmdit(backend, tensor_types, "model.diffusion_model", version) {
     }
 
     void alloc_params_buffer() {
@@ -128,9 +129,9 @@ struct FluxModel : public DiffusionModel {
     Flux::FluxRunner flux;
 
     FluxModel(ggml_backend_t backend,
-              ggml_type wtype,
+              std::map<std::string, enum ggml_type>& tensor_types,
              SDVersion version = VERSION_FLUX_DEV)
-        : flux(backend, wtype, version) {
+        : flux(backend, tensor_types, "model.diffusion_model", version) {
     }
 
     void alloc_params_buffer() {

esrgan.hpp

Lines changed: 5 additions & 4 deletions
@@ -142,10 +142,11 @@ struct ESRGAN : public GGMLRunner {
     int scale     = 4;
     int tile_size = 128;  // avoid cuda OOM for 4gb VRAM
 
-    ESRGAN(ggml_backend_t backend,
-           ggml_type wtype)
-        : GGMLRunner(backend, wtype) {
-        rrdb_net.init(params_ctx, wtype);
+    ESRGAN(ggml_backend_t backend)
+        : GGMLRunner(backend) {
+    }
+    void init_params(std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix) {
+        rrdb_net.init(params_ctx, tensor_types, prefix);
     }
 
     std::string get_desc() {

examples/cli/main.cpp

Lines changed: 1 addition & 2 deletions
@@ -915,8 +915,7 @@ int main(int argc, const char* argv[]) {
     int upscale_factor = 4;  // unused for RealESRGAN_x4plus_anime_6B.pth
     if (params.esrgan_path.size() > 0 && params.upscale_repeats > 0) {
         upscaler_ctx_t* upscaler_ctx = new_upscaler_ctx(params.esrgan_path.c_str(),
-                                                        params.n_threads,
-                                                        params.wtype);
+                                                        params.n_threads);
 
         if (upscaler_ctx == NULL) {
             printf("new_upscaler_ctx failed\n");

flux.hpp

Lines changed: 9 additions & 7 deletions
@@ -35,8 +35,9 @@ namespace Flux {
     int64_t hidden_size;
     float eps;
 
-    void init_params(struct ggml_context* ctx, ggml_type wtype) {
-        params["scale"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size);
+    void init_params(struct ggml_context* ctx, const std::string prefix, std::map<std::string, enum ggml_type>& tensor_types, std::map<std::string, struct ggml_tensor*>& params) {
+        ggml_type wtype = GGML_TYPE_F32;  // (tensor_types.find(prefix + "scale") != tensor_types.end()) ? tensor_types[prefix + "scale"] : GGML_TYPE_F32;
+        params["scale"] = ggml_new_tensor_1d(ctx, wtype, hidden_size);
     }
 
 public:
@@ -807,14 +808,15 @@ namespace Flux {
     std::vector<float> pe_vec;  // for cache
 
     FluxRunner(ggml_backend_t backend,
-               ggml_type wtype,
-               SDVersion version = VERSION_FLUX_DEV)
-        : GGMLRunner(backend, wtype) {
+               std::map<std::string, enum ggml_type>& tensor_types = std::map<std::string, enum ggml_type>(),
+               const std::string prefix = "",
+               SDVersion version = VERSION_FLUX_DEV)
+        : GGMLRunner(backend) {
         if (version == VERSION_FLUX_SCHNELL) {
             flux_params.guidance_embed = false;
         }
         flux = Flux(flux_params);
-        flux.init(params_ctx, wtype);
+        flux.init(params_ctx, tensor_types, prefix);
     }
 
     std::string get_desc() {
@@ -929,7 +931,7 @@ namespace Flux {
     // ggml_backend_t backend = ggml_backend_cuda_init(0);
     ggml_backend_t backend       = ggml_backend_cpu_init();
     ggml_type model_data_type    = GGML_TYPE_Q8_0;
-    std::shared_ptr<FluxRunner> flux = std::shared_ptr<FluxRunner>(new FluxRunner(backend, model_data_type));
+    std::shared_ptr<FluxRunner> flux = std::shared_ptr<FluxRunner>(new FluxRunner(backend));
     {
         LOG_INFO("loading from '%s'", file_path.c_str());
