Skip to content

Refactor: Flexible model architecture for dit models (Flux & SD3) #490

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Nov 30, 2024
46 changes: 27 additions & 19 deletions clip.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -545,9 +545,12 @@ class CLIPEmbeddings : public GGMLBlock {
int64_t vocab_size;
int64_t num_positions;

void init_params(struct ggml_context* ctx, ggml_type wtype) {
params["token_embedding.weight"] = ggml_new_tensor_2d(ctx, wtype, embed_dim, vocab_size);
params["position_embedding.weight"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, embed_dim, num_positions);
void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
enum ggml_type token_wtype = (tensor_types.find(prefix + "token_embedding.weight") != tensor_types.end()) ? tensor_types[prefix + "token_embedding.weight"] : GGML_TYPE_F32;
enum ggml_type position_wtype = GGML_TYPE_F32; //(tensor_types.find(prefix + "position_embedding.weight") != tensor_types.end()) ? tensor_types[prefix + "position_embedding.weight"] : GGML_TYPE_F32;

params["token_embedding.weight"] = ggml_new_tensor_2d(ctx, token_wtype, embed_dim, vocab_size);
params["position_embedding.weight"] = ggml_new_tensor_2d(ctx, position_wtype, embed_dim, num_positions);
}

public:
Expand Down Expand Up @@ -591,11 +594,14 @@ class CLIPVisionEmbeddings : public GGMLBlock {
int64_t image_size;
int64_t num_patches;
int64_t num_positions;
void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
enum ggml_type patch_wtype = GGML_TYPE_F16; // tensor_types.find(prefix + "patch_embedding.weight") != tensor_types.end() ? tensor_types[prefix + "patch_embedding.weight"] : GGML_TYPE_F16;
enum ggml_type class_wtype = GGML_TYPE_F32; // tensor_types.find(prefix + "class_embedding") != tensor_types.end() ? tensor_types[prefix + "class_embedding"] : GGML_TYPE_F32;
enum ggml_type position_wtype = GGML_TYPE_F32; // tensor_types.find(prefix + "position_embedding.weight") != tensor_types.end() ? tensor_types[prefix + "position_embedding.weight"] : GGML_TYPE_F32;

void init_params(struct ggml_context* ctx, ggml_type wtype) {
params["patch_embedding.weight"] = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, patch_size, patch_size, num_channels, embed_dim);
params["class_embedding"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, embed_dim);
params["position_embedding.weight"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, embed_dim, num_positions);
params["patch_embedding.weight"] = ggml_new_tensor_4d(ctx, patch_wtype, patch_size, patch_size, num_channels, embed_dim);
params["class_embedding"] = ggml_new_tensor_1d(ctx, class_wtype, embed_dim);
params["position_embedding.weight"] = ggml_new_tensor_2d(ctx, position_wtype, embed_dim, num_positions);
}

public:
Expand Down Expand Up @@ -651,9 +657,10 @@ enum CLIPVersion {

class CLIPTextModel : public GGMLBlock {
protected:
void init_params(struct ggml_context* ctx, ggml_type wtype) {
void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
if (version == OPEN_CLIP_VIT_BIGG_14) {
params["text_projection"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, projection_dim, hidden_size);
enum ggml_type wtype = GGML_TYPE_F32; // tensor_types.find(prefix + "text_projection") != tensor_types.end() ? tensor_types[prefix + "text_projection"] : GGML_TYPE_F32;
params["text_projection"] = ggml_new_tensor_2d(ctx, wtype, projection_dim, hidden_size);
}
}

Expand Down Expand Up @@ -798,9 +805,9 @@ class CLIPProjection : public UnaryBlock {
int64_t out_features;
bool transpose_weight;

void init_params(struct ggml_context* ctx, ggml_type wtype) {
void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
enum ggml_type wtype = tensor_types.find(prefix + "weight") != tensor_types.end() ? tensor_types[prefix + "weight"] : GGML_TYPE_F32;
if (transpose_weight) {
LOG_ERROR("transpose_weight");
params["weight"] = ggml_new_tensor_2d(ctx, wtype, out_features, in_features);
} else {
params["weight"] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features);
Expand Down Expand Up @@ -861,12 +868,13 @@ struct CLIPTextModelRunner : public GGMLRunner {
CLIPTextModel model;

CLIPTextModelRunner(ggml_backend_t backend,
ggml_type wtype,
std::map<std::string, enum ggml_type>& tensor_types,
const std::string prefix,
CLIPVersion version = OPENAI_CLIP_VIT_L_14,
int clip_skip_value = 1,
bool with_final_ln = true)
: GGMLRunner(backend, wtype), model(version, clip_skip_value, with_final_ln) {
model.init(params_ctx, wtype);
: GGMLRunner(backend), model(version, clip_skip_value, with_final_ln) {
model.init(params_ctx, tensor_types, prefix);
}

std::string get_desc() {
Expand Down Expand Up @@ -908,13 +916,13 @@ struct CLIPTextModelRunner : public GGMLRunner {
struct ggml_tensor* embeddings = NULL;

if (num_custom_embeddings > 0 && custom_embeddings_data != NULL) {
auto custom_embeddings = ggml_new_tensor_2d(compute_ctx,
wtype,
model.hidden_size,
num_custom_embeddings);
auto token_embed_weight = model.get_token_embed_weight();
auto custom_embeddings = ggml_new_tensor_2d(compute_ctx,
token_embed_weight->type,
model.hidden_size,
num_custom_embeddings);
set_backend_tensor_data(custom_embeddings, custom_embeddings_data);

auto token_embed_weight = model.get_token_embed_weight();
// concatenate custom embeddings
embeddings = ggml_concat(compute_ctx, token_embed_weight, custom_embeddings, 1);
}
Expand Down
14 changes: 9 additions & 5 deletions common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,11 @@ class GEGLU : public GGMLBlock {
int64_t dim_in;
int64_t dim_out;

void init_params(struct ggml_context* ctx, ggml_type wtype) {
params["proj.weight"] = ggml_new_tensor_2d(ctx, wtype, dim_in, dim_out * 2);
params["proj.bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dim_out * 2);
void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, std::string prefix = "") {
enum ggml_type wtype = (tensor_types.find(prefix + "proj.weight") != tensor_types.end()) ? tensor_types[prefix + "proj.weight"] : GGML_TYPE_F32;
enum ggml_type bias_wtype = GGML_TYPE_F32; //(tensor_types.find(prefix + "proj.bias") != tensor_types.end()) ? tensor_types[prefix + "proj.bias"] : GGML_TYPE_F32;
params["proj.weight"] = ggml_new_tensor_2d(ctx, wtype, dim_in, dim_out * 2);
params["proj.bias"] = ggml_new_tensor_1d(ctx, bias_wtype, dim_out * 2);
}

public:
Expand Down Expand Up @@ -438,8 +440,10 @@ class SpatialTransformer : public GGMLBlock {

class AlphaBlender : public GGMLBlock {
protected:
void init_params(struct ggml_context* ctx, ggml_type wtype) {
params["mix_factor"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, std::string prefix = "") {
// Get the type of the "mix_factor" tensor from the input tensors map with the specified prefix
enum ggml_type wtype = GGML_TYPE_F32; //(tensor_types.ypes.find(prefix + "mix_factor") != tensor_types.end()) ? tensor_types[prefix + "mix_factor"] : GGML_TYPE_F32;
params["mix_factor"] = ggml_new_tensor_1d(ctx, wtype, 1);
}

float get_alpha() {
Expand Down
44 changes: 20 additions & 24 deletions conditioner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
SDVersion version = VERSION_SD1;
PMVersion pm_version = PM_VERSION_1;
CLIPTokenizer tokenizer;
ggml_type wtype;
std::shared_ptr<CLIPTextModelRunner> text_model;
std::shared_ptr<CLIPTextModelRunner> text_model2;

Expand All @@ -57,25 +56,25 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
std::vector<std::string> readed_embeddings;

FrozenCLIPEmbedderWithCustomWords(ggml_backend_t backend,
ggml_type wtype,
std::map<std::string, enum ggml_type>& tensor_types,
const std::string& embd_dir,
SDVersion version = VERSION_SD1,
PMVersion pv = PM_VERSION_1,
int clip_skip = -1)
: version(version), pm_version(pv), tokenizer(version == VERSION_SD2 ? 0 : 49407), embd_dir(embd_dir), wtype(wtype) {
: version(version), pm_version(pv), tokenizer(version == VERSION_SD2 ? 0 : 49407), embd_dir(embd_dir) {
if (clip_skip <= 0) {
clip_skip = 1;
if (version == VERSION_SD2 || version == VERSION_SDXL) {
clip_skip = 2;
}
}
if (version == VERSION_SD1) {
text_model = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip);
text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip);
} else if (version == VERSION_SD2) {
text_model = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPEN_CLIP_VIT_H_14, clip_skip);
text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPEN_CLIP_VIT_H_14, clip_skip);
} else if (version == VERSION_SDXL) {
text_model = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip, false);
text_model2 = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip, false);
text_model2 = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.1.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
}
}

Expand Down Expand Up @@ -138,14 +137,14 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
LOG_DEBUG("embedding wrong hidden size, got %i, expected %i", tensor_storage.ne[0], hidden_size);
return false;
}
embd = ggml_new_tensor_2d(embd_ctx, wtype, hidden_size, tensor_storage.n_dims > 1 ? tensor_storage.ne[1] : 1);
embd = ggml_new_tensor_2d(embd_ctx, tensor_storage.type, hidden_size, tensor_storage.n_dims > 1 ? tensor_storage.ne[1] : 1);
*dst_tensor = embd;
return true;
};
model_loader.load_tensors(on_load, NULL);
readed_embeddings.push_back(embd_name);
token_embed_custom.resize(token_embed_custom.size() + ggml_nbytes(embd));
memcpy((void*)(token_embed_custom.data() + num_custom_embeddings * hidden_size * ggml_type_size(wtype)),
memcpy((void*)(token_embed_custom.data() + num_custom_embeddings * hidden_size * ggml_type_size(embd->type)),
embd->data,
ggml_nbytes(embd));
for (int i = 0; i < embd->ne[1]; i++) {
Expand Down Expand Up @@ -590,9 +589,9 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
struct FrozenCLIPVisionEmbedder : public GGMLRunner {
CLIPVisionModelProjection vision_model;

FrozenCLIPVisionEmbedder(ggml_backend_t backend, ggml_type wtype)
: vision_model(OPEN_CLIP_VIT_H_14, true), GGMLRunner(backend, wtype) {
vision_model.init(params_ctx, wtype);
FrozenCLIPVisionEmbedder(ggml_backend_t backend, std::map<std::string, enum ggml_type>& tensor_types)
: vision_model(OPEN_CLIP_VIT_H_14, true), GGMLRunner(backend) {
vision_model.init(params_ctx, tensor_types, "cond_stage_model.transformer");
}

std::string get_desc() {
Expand Down Expand Up @@ -627,7 +626,6 @@ struct FrozenCLIPVisionEmbedder : public GGMLRunner {
};

struct SD3CLIPEmbedder : public Conditioner {
ggml_type wtype;
CLIPTokenizer clip_l_tokenizer;
CLIPTokenizer clip_g_tokenizer;
T5UniGramTokenizer t5_tokenizer;
Expand All @@ -636,15 +634,15 @@ struct SD3CLIPEmbedder : public Conditioner {
std::shared_ptr<T5Runner> t5;

SD3CLIPEmbedder(ggml_backend_t backend,
ggml_type wtype,
std::map<std::string, enum ggml_type>& tensor_types,
int clip_skip = -1)
: wtype(wtype), clip_g_tokenizer(0) {
: clip_g_tokenizer(0) {
if (clip_skip <= 0) {
clip_skip = 2;
}
clip_l = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip, false);
clip_g = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
t5 = std::make_shared<T5Runner>(backend, wtype);
clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip, false);
clip_g = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_g.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
t5 = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
}

void set_clip_skip(int clip_skip) {
Expand Down Expand Up @@ -974,21 +972,19 @@ struct SD3CLIPEmbedder : public Conditioner {
};

struct FluxCLIPEmbedder : public Conditioner {
ggml_type wtype;
CLIPTokenizer clip_l_tokenizer;
T5UniGramTokenizer t5_tokenizer;
std::shared_ptr<CLIPTextModelRunner> clip_l;
std::shared_ptr<T5Runner> t5;

FluxCLIPEmbedder(ggml_backend_t backend,
ggml_type wtype,
int clip_skip = -1)
: wtype(wtype) {
std::map<std::string, enum ggml_type>& tensor_types,
int clip_skip = -1) {
if (clip_skip <= 0) {
clip_skip = 2;
}
clip_l = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip, true);
t5 = std::make_shared<T5Runner>(backend, wtype);
clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip, true);
t5 = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
}

void set_clip_skip(int clip_skip) {
Expand Down
6 changes: 3 additions & 3 deletions control.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,10 +317,10 @@ struct ControlNet : public GGMLRunner {
bool guided_hint_cached = false;

ControlNet(ggml_backend_t backend,
ggml_type wtype,
std::map<std::string, enum ggml_type>& tensor_types,
SDVersion version = VERSION_SD1)
: GGMLRunner(backend, wtype), control_net(version) {
control_net.init(params_ctx, wtype);
: GGMLRunner(backend), control_net(version) {
control_net.init(params_ctx, tensor_types, "");
}

~ControlNet() {
Expand Down
16 changes: 7 additions & 9 deletions diffusion_model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ struct UNetModel : public DiffusionModel {
UNetModelRunner unet;

UNetModel(ggml_backend_t backend,
ggml_type wtype,
std::map<std::string, enum ggml_type>& tensor_types,
SDVersion version = VERSION_SD1,
bool flash_attn = false)
: unet(backend, wtype, version, flash_attn) {
: unet(backend, tensor_types, "model.diffusion_model", version, flash_attn) {
}

void alloc_params_buffer() {
Expand Down Expand Up @@ -83,9 +83,8 @@ struct MMDiTModel : public DiffusionModel {
MMDiTRunner mmdit;

MMDiTModel(ggml_backend_t backend,
ggml_type wtype,
SDVersion version = VERSION_SD3_2B)
: mmdit(backend, wtype, version) {
std::map<std::string, enum ggml_type>& tensor_types)
: mmdit(backend, tensor_types, "model.diffusion_model") {
}

void alloc_params_buffer() {
Expand Down Expand Up @@ -133,10 +132,9 @@ struct FluxModel : public DiffusionModel {
Flux::FluxRunner flux;

FluxModel(ggml_backend_t backend,
ggml_type wtype,
SDVersion version = VERSION_FLUX_DEV,
bool flash_attn = false)
: flux(backend, wtype, version, flash_attn) {
std::map<std::string, enum ggml_type>& tensor_types,
bool flash_attn = false)
: flux(backend, tensor_types, "model.diffusion_model", flash_attn) {
}

void alloc_params_buffer() {
Expand Down
7 changes: 3 additions & 4 deletions esrgan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,9 @@ struct ESRGAN : public GGMLRunner {
int scale = 4;
int tile_size = 128; // avoid cuda OOM for 4gb VRAM

ESRGAN(ggml_backend_t backend,
ggml_type wtype)
: GGMLRunner(backend, wtype) {
rrdb_net.init(params_ctx, wtype);
ESRGAN(ggml_backend_t backend, std::map<std::string, enum ggml_type>& tensor_types)
: GGMLRunner(backend) {
rrdb_net.init(params_ctx, tensor_types, "");
}

std::string get_desc() {
Expand Down
3 changes: 1 addition & 2 deletions examples/cli/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1010,8 +1010,7 @@ int main(int argc, const char* argv[]) {
int upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth
if (params.esrgan_path.size() > 0 && params.upscale_repeats > 0) {
upscaler_ctx_t* upscaler_ctx = new_upscaler_ctx(params.esrgan_path.c_str(),
params.n_threads,
params.wtype);
params.n_threads);

if (upscaler_ctx == NULL) {
printf("new_upscaler_ctx failed\n");
Expand Down
Loading
Loading