Commit 52b67c5

feat: add flux2 support (#1016)

* add flux2 support
* rename qwenvl to llm
* add Flux2FlowDenoiser
* update docs

1 parent 2034588 commit 52b67c5

21 files changed: +489707 −577 lines

README.md

Lines changed: 4 additions & 2 deletions

```diff
@@ -37,7 +37,8 @@ API and command-line option may change frequently.***
 - SDXL, [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo)
 - [Some SD1.x and SDXL distilled models](./docs/distilled_sd.md)
 - [SD3/SD3.5](./docs/sd3.md)
-- [Flux-dev/Flux-schnell](./docs/flux.md)
+- [FLUX.1-dev/FLUX.1-schnell](./docs/flux.md)
+- [FLUX.2-dev](./docs/flux2.md)
 - [Chroma](./docs/chroma.md)
 - [Chroma1-Radiance](./docs/chroma_radiance.md)
 - [Qwen Image](./docs/qwen_image.md)
@@ -118,7 +119,8 @@ If you want to improve performance or reduce VRAM/RAM usage, please refer to [pe
 
 - [SD1.x/SD2.x/SDXL](./docs/sd.md)
 - [SD3/SD3.5](./docs/sd3.md)
-- [Flux-dev/Flux-schnell](./docs/flux.md)
+- [FLUX.1-dev/FLUX.1-schnell](./docs/flux.md)
+- [FLUX.2-dev](./docs/flux2.md)
 - [FLUX.1-Kontext-dev](./docs/kontext.md)
 - [Chroma](./docs/chroma.md)
 - [🔥Qwen Image](./docs/qwen_image.md)
```

assets/flux2/example.png

556 KB

conditioner.hpp

Lines changed: 92 additions & 48 deletions

```diff
@@ -2,7 +2,7 @@
 #define __CONDITIONER_HPP__
 
 #include "clip.hpp"
-#include "qwenvl.hpp"
+#include "llm.hpp"
 #include "t5.hpp"
 
 struct SDCondition {
```

```diff
@@ -1623,61 +1623,72 @@ struct T5CLIPEmbedder : public Conditioner {
     }
 };
 
-struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
-    Qwen::Qwen2Tokenizer tokenizer;
-    std::shared_ptr<Qwen::Qwen2_5_VLRunner> qwenvl;
-
-    Qwen2_5_VLCLIPEmbedder(ggml_backend_t backend,
-                           bool offload_params_to_cpu,
-                           const String2TensorStorage& tensor_storage_map = {},
-                           const std::string prefix = "",
-                           bool enable_vision = false) {
-        qwenvl = std::make_shared<Qwen::Qwen2_5_VLRunner>(backend,
-                                                          offload_params_to_cpu,
-                                                          tensor_storage_map,
-                                                          "text_encoders.qwen2vl",
-                                                          enable_vision);
+struct LLMEmbedder : public Conditioner {
+    SDVersion version;
+    std::shared_ptr<LLM::BPETokenizer> tokenizer;
+    std::shared_ptr<LLM::LLMRunner> llm;
+
+    LLMEmbedder(ggml_backend_t backend,
+                bool offload_params_to_cpu,
+                const String2TensorStorage& tensor_storage_map = {},
+                SDVersion version = VERSION_QWEN_IMAGE,
+                const std::string prefix = "",
+                bool enable_vision = false)
+        : version(version) {
+        LLM::LLMArch arch = LLM::LLMArch::QWEN2_5_VL;
+        if (sd_version_is_flux2(version)) {
+            arch = LLM::LLMArch::MISTRAL_SMALL_3_2;
+        }
+        if (arch == LLM::LLMArch::MISTRAL_SMALL_3_2) {
+            tokenizer = std::make_shared<LLM::MistralTokenizer>();
+        } else {
+            tokenizer = std::make_shared<LLM::Qwen2Tokenizer>();
+        }
+        llm = std::make_shared<LLM::LLMRunner>(arch,
+                                               backend,
+                                               offload_params_to_cpu,
+                                               tensor_storage_map,
+                                               "text_encoders.llm",
+                                               enable_vision);
     }
 
     void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) override {
-        qwenvl->get_param_tensors(tensors, "text_encoders.qwen2vl");
+        llm->get_param_tensors(tensors, "text_encoders.llm");
     }
 
     void alloc_params_buffer() override {
-        qwenvl->alloc_params_buffer();
+        llm->alloc_params_buffer();
     }
 
     void free_params_buffer() override {
-        qwenvl->free_params_buffer();
+        llm->free_params_buffer();
     }
 
     size_t get_params_buffer_size() override {
         size_t buffer_size = 0;
-        buffer_size += qwenvl->get_params_buffer_size();
+        buffer_size += llm->get_params_buffer_size();
         return buffer_size;
     }
 
     void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
-        if (qwenvl) {
-            qwenvl->set_weight_adapter(adapter);
+        if (llm) {
+            llm->set_weight_adapter(adapter);
         }
     }
 
     std::tuple<std::vector<int>, std::vector<float>> tokenize(std::string text,
-                                                              size_t max_length = 0,
-                                                              size_t system_prompt_length = 0,
-                                                              bool padding = false) {
+                                                              std::pair<int, int> attn_range,
+                                                              size_t max_length = 0,
+                                                              bool padding = false) {
         std::vector<std::pair<std::string, float>> parsed_attention;
-        if (system_prompt_length > 0) {
-            parsed_attention.emplace_back(text.substr(0, system_prompt_length), 1.f);
-            auto new_parsed_attention = parse_prompt_attention(text.substr(system_prompt_length, text.size() - system_prompt_length));
+        parsed_attention.emplace_back(text.substr(0, attn_range.first), 1.f);
+        if (attn_range.second - attn_range.first > 0) {
+            auto new_parsed_attention = parse_prompt_attention(text.substr(attn_range.first, attn_range.second - attn_range.first));
             parsed_attention.insert(parsed_attention.end(),
                                     new_parsed_attention.begin(),
                                     new_parsed_attention.end());
-        } else {
-            parsed_attention = parse_prompt_attention(text);
         }
-
+        parsed_attention.emplace_back(text.substr(attn_range.second), 1.f);
         {
             std::stringstream ss;
             ss << "[";
@@ -1693,12 +1704,12 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
         for (const auto& item : parsed_attention) {
             const std::string& curr_text = item.first;
             float curr_weight = item.second;
-            std::vector<int> curr_tokens = tokenizer.tokenize(curr_text, nullptr);
+            std::vector<int> curr_tokens = tokenizer->tokenize(curr_text, nullptr);
             tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
             weights.insert(weights.end(), curr_tokens.size(), curr_weight);
         }
 
-        tokenizer.pad_tokens(tokens, weights, max_length, padding);
+        tokenizer->pad_tokens(tokens, weights, max_length, padding);
 
         // for (int i = 0; i < tokens.size(); i++) {
         //     std::cout << tokens[i] << ":" << weights[i] << ", " << i << std::endl;
```
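
The `attn_range` refactor above replaces the old `system_prompt_length` convention: instead of treating everything after the system prompt as weightable, the caller now records the exact byte range of the user text inside the templated prompt, and only that slice goes through the prompt-attention parser; the template prefix and suffix always keep weight 1.0. A minimal standalone sketch of that three-way split (the stub parser and the example template/prompt strings are illustrative, not the repository's code):

```cpp
#include <iostream>
#include <string>
#include <utility>
#include <vector>

using Weighted = std::vector<std::pair<std::string, float>>;

// Stand-in for the repository's parse_prompt_attention(); it just tags the
// whole slice with weight 1.5f so the split is visible. The real parser
// handles (text:1.3)-style emphasis syntax.
static Weighted parse_prompt_attention_stub(const std::string& text) {
    return {{text, 1.5f}};
}

int main() {
    // Build a templated prompt and record the byte range of the user text.
    std::string prompt = "[SYSTEM_PROMPT]...[/SYSTEM_PROMPT][INST]";
    std::pair<int, int> attn_range;
    attn_range.first = static_cast<int>(prompt.size());
    prompt += "a red fox in the snow";
    attn_range.second = static_cast<int>(prompt.size());
    prompt += "[/INST]";

    // Template prefix/suffix keep weight 1.0; only the user slice is parsed.
    Weighted parsed;
    parsed.emplace_back(prompt.substr(0, attn_range.first), 1.f);
    if (attn_range.second - attn_range.first > 0) {
        auto inner = parse_prompt_attention_stub(
            prompt.substr(attn_range.first, attn_range.second - attn_range.first));
        parsed.insert(parsed.end(), inner.begin(), inner.end());
    }
    parsed.emplace_back(prompt.substr(attn_range.second), 1.f);

    for (const auto& item : parsed) {
        std::cout << item.second << " => " << item.first << "\n";
    }
    return 0;
}
```

This also lets a template carry a suffix after the user text (like FLUX.2's `[/INST]` below) without it being swept into the attention parser, which the old single-cut `system_prompt_length` scheme could not express.
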
```diff
@@ -1713,9 +1724,10 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
                            const ConditionerParams& conditioner_params) override {
         std::string prompt;
         std::vector<std::pair<int, ggml_tensor*>> image_embeds;
-        size_t system_prompt_length = 0;
+        std::pair<int, int> prompt_attn_range;
         int prompt_template_encode_start_idx = 34;
-        if (qwenvl->enable_vision && conditioner_params.ref_images.size() > 0) {
+        std::set<int> out_layers;
+        if (llm->enable_vision && conditioner_params.ref_images.size() > 0) {
             LOG_INFO("QwenImageEditPlusPipeline");
             prompt_template_encode_start_idx = 64;
             int image_embed_idx = 64 + 6;
@@ -1727,7 +1739,7 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
 
             for (int i = 0; i < conditioner_params.ref_images.size(); i++) {
                 sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(*conditioner_params.ref_images[i]);
-                double factor = qwenvl->params.vision.patch_size * qwenvl->params.vision.spatial_merge_size;
+                double factor = llm->params.vision.patch_size * llm->params.vision.spatial_merge_size;
                 int height = image.height;
                 int width = image.width;
                 int h_bar = static_cast<int>(std::round(height / factor)) * factor;
@@ -1757,7 +1769,7 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
                 resized_image.data = nullptr;
 
                 ggml_tensor* image_embed = nullptr;
-                qwenvl->encode_image(n_threads, image_tensor, &image_embed, work_ctx);
+                llm->encode_image(n_threads, image_tensor, &image_embed, work_ctx);
                 image_embeds.emplace_back(image_embed_idx, image_embed);
                 image_embed_idx += 1 + image_embed->ne[1] + 6;
 
@@ -1771,17 +1783,37 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
             }
 
             prompt = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n";
-
-            system_prompt_length = prompt.size();
-
             prompt += img_prompt;
+
+            prompt_attn_range.first = prompt.size();
             prompt += conditioner_params.text;
+            prompt_attn_range.second = prompt.size();
+
             prompt += "<|im_end|>\n<|im_start|>assistant\n";
+        } else if (sd_version_is_flux2(version)) {
+            prompt_template_encode_start_idx = 0;
+            out_layers = {10, 20, 30};
+
+            prompt = "[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]";
+
+            prompt_attn_range.first = prompt.size();
+            prompt += conditioner_params.text;
+            prompt_attn_range.second = prompt.size();
+
+            prompt += "[/INST]";
         } else {
-            prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n" + conditioner_params.text + "<|im_end|>\n<|im_start|>assistant\n";
+            prompt_template_encode_start_idx = 34;
+
+            prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n";
+
+            prompt_attn_range.first = prompt.size();
+            prompt += conditioner_params.text;
+            prompt_attn_range.second = prompt.size();
+
+            prompt += "<|im_end|>\n<|im_start|>assistant\n";
         }
 
-        auto tokens_and_weights = tokenize(prompt, 0, system_prompt_length, false);
+        auto tokens_and_weights = tokenize(prompt, prompt_attn_range, 0, false);
         auto& tokens = std::get<0>(tokens_and_weights);
         auto& weights = std::get<1>(tokens_and_weights);
 
@@ -1790,11 +1822,12 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
 
         auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);
 
-        qwenvl->compute(n_threads,
-                        input_ids,
-                        image_embeds,
-                        &hidden_states,
-                        work_ctx);
+        llm->compute(n_threads,
+                     input_ids,
+                     image_embeds,
+                     out_layers,
+                     &hidden_states,
+                     work_ctx);
         {
             auto tensor = hidden_states;
             float original_mean = ggml_ext_tensor_mean(tensor);
@@ -1813,14 +1846,25 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
 
         GGML_ASSERT(hidden_states->ne[1] > prompt_template_encode_start_idx);
 
+        int64_t zero_pad_len = 0;
+        if (sd_version_is_flux2(version)) {
+            int64_t min_length = 512;
+            if (hidden_states->ne[1] - prompt_template_encode_start_idx < min_length) {
+                zero_pad_len = min_length - hidden_states->ne[1] + prompt_template_encode_start_idx;
+            }
+        }
+
         ggml_tensor* new_hidden_states = ggml_new_tensor_3d(work_ctx,
                                                             GGML_TYPE_F32,
                                                             hidden_states->ne[0],
-                                                            hidden_states->ne[1] - prompt_template_encode_start_idx,
+                                                            hidden_states->ne[1] - prompt_template_encode_start_idx + zero_pad_len,
                                                             hidden_states->ne[2]);
 
         ggml_ext_tensor_iter(new_hidden_states, [&](ggml_tensor* new_hidden_states, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
-            float value = ggml_ext_tensor_get_f32(hidden_states, i0, i1 + prompt_template_encode_start_idx, i2, i3);
+            float value = 0.f;
+            if (i1 + prompt_template_encode_start_idx < hidden_states->ne[1]) {
+                value = ggml_ext_tensor_get_f32(hidden_states, i0, i1 + prompt_template_encode_start_idx, i2, i3);
+            }
             ggml_ext_tensor_set_f32(new_hidden_states, value, i0, i1, i2, i3);
         });
 
```
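
One FLUX.2-specific detail in the last hunk: after the template tokens are trimmed off, the hidden states are zero-padded up to a minimum sequence length of 512, and the copy lambda writes `0.f` for rows past the original length. A small sketch of just the length arithmetic (the 77-token prompt length is an assumed example):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t min_length = 512;  // FLUX.2 minimum context length (from the diff)
    int64_t seq_len   = 77;          // encoder tokens for a short prompt (assumed)
    int64_t start_idx = 0;           // flux2 prompt_template_encode_start_idx

    // Tokens kept after trimming the template, then padded with zeros to 512.
    int64_t kept         = seq_len - start_idx;
    int64_t zero_pad_len = std::max<int64_t>(0, min_length - kept);

    std::printf("kept=%lld zero_pad=%lld final=%lld\n",
                (long long)kept, (long long)zero_pad_len,
                (long long)(kept + zero_pad_len));
    return 0;
}
```
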

denoiser.hpp

Lines changed: 40 additions & 4 deletions
```diff
@@ -356,7 +356,7 @@ struct Denoiser {
     virtual ggml_tensor* noise_scaling(float sigma, ggml_tensor* noise, ggml_tensor* latent) = 0;
     virtual ggml_tensor* inverse_noise_scaling(float sigma, ggml_tensor* latent) = 0;
 
-    virtual std::vector<float> get_sigmas(uint32_t n, scheduler_t scheduler_type, SDVersion version) {
+    virtual std::vector<float> get_sigmas(uint32_t n, int /*image_seq_len*/, scheduler_t scheduler_type, SDVersion version) {
         auto bound_t_to_sigma = std::bind(&Denoiser::t_to_sigma, this, std::placeholders::_1);
         std::shared_ptr<SigmaScheduler> scheduler;
         switch (scheduler_type) {
@@ -582,10 +582,14 @@ struct FluxFlowDenoiser : public Denoiser {
         set_parameters(shift);
     }
 
-    void set_parameters(float shift = 1.15f) {
+    void set_shift(float shift) {
         this->shift = shift;
-        for (int i = 1; i < TIMESTEPS + 1; i++) {
-            sigmas[i - 1] = t_to_sigma(i / TIMESTEPS * TIMESTEPS);
+    }
+
+    void set_parameters(float shift) {
+        set_shift(shift);
+        for (int i = 0; i < TIMESTEPS; i++) {
+            sigmas[i] = t_to_sigma(i);
         }
     }
 
@@ -627,6 +631,38 @@ struct FluxFlowDenoiser : public Denoiser {
     }
 };
 
+struct Flux2FlowDenoiser : public FluxFlowDenoiser {
+    Flux2FlowDenoiser() = default;
+
+    float compute_empirical_mu(uint32_t n, int image_seq_len) {
+        const float a1 = 8.73809524e-05f;
+        const float b1 = 1.89833333f;
+        const float a2 = 0.00016927f;
+        const float b2 = 0.45666666f;
+
+        if (image_seq_len > 4300) {
+            float mu = a2 * image_seq_len + b2;
+            return mu;
+        }
+
+        float m_200 = a2 * image_seq_len + b2;
+        float m_10  = a1 * image_seq_len + b1;
+
+        float a  = (m_200 - m_10) / 190.0f;
+        float b  = m_200 - 200.0f * a;
+        float mu = a * n + b;
+
+        return mu;
+    }
+
+    std::vector<float> get_sigmas(uint32_t n, int image_seq_len, scheduler_t scheduler_type, SDVersion version) override {
+        float mu = compute_empirical_mu(n, image_seq_len);
+        LOG_DEBUG("Flux2FlowDenoiser: set shift to %.3f", mu);
+        set_shift(mu);
+        return Denoiser::get_sigmas(n, image_seq_len, scheduler_type, version);
+    }
+};
+
 typedef std::function<ggml_tensor*(ggml_tensor*, float, int)> denoise_cb_t;
 
 // k diffusion reverse ODE: dx = (x - D(x;\sigma)) / \sigma dt; \sigma(t) = t
```
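
`Flux2FlowDenoiser` picks its flow shift `mu` empirically from both the step count `n` and the image token count: the `(a1, b1)` and `(a2, b2)` pairs are linear fits in `image_seq_len` at the calibration points n = 10 and n = 200, and `mu` is interpolated linearly in `n` between those two fits (sequences longer than 4300 tokens use the n = 200 fit directly). A standalone sketch that reproduces the arithmetic; the sequence length of 4096 is an assumed example (e.g. a 64x64 grid of latent patches):

```cpp
#include <cstdint>
#include <cstdio>

// Mirrors Flux2FlowDenoiser::compute_empirical_mu from the diff above.
static float compute_empirical_mu(uint32_t n, int image_seq_len) {
    const float a1 = 8.73809524e-05f, b1 = 1.89833333f;  // linear fit at n = 10
    const float a2 = 0.00016927f,     b2 = 0.45666666f;  // linear fit at n = 200

    if (image_seq_len > 4300) {
        return a2 * image_seq_len + b2;  // long sequences: use the n = 200 fit directly
    }
    float m_200 = a2 * image_seq_len + b2;
    float m_10  = a1 * image_seq_len + b1;
    float a     = (m_200 - m_10) / 190.0f;  // slope in n between the two fits
    float b     = m_200 - 200.0f * a;
    return a * n + b;
}

int main() {
    // Print the shift for a few step counts at an assumed 4096-token image.
    for (uint32_t n : {10u, 20u, 50u, 200u}) {
        std::printf("steps=%3u  seq_len=4096  mu=%.3f\n", n, compute_empirical_mu(n, 4096));
    }
    return 0;
}
```

As a sanity check, at 4096 tokens the n = 200 fit lands at mu ≈ 1.15, the same value the old `set_parameters(float shift = 1.15f)` default used for FLUX.1.
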

docs/flux2.md

Lines changed: 21 additions & 0 deletions

````diff
@@ -0,0 +1,21 @@
+# How to Use
+
+## Download weights
+
+- Download FLUX.2-dev
+  - gguf: https://huggingface.co/city96/FLUX.2-dev-gguf/tree/main
+- Download vae
+  - safetensors: https://huggingface.co/black-forest-labs/FLUX.2-dev/tree/main
+- Download Mistral-Small-3.2-24B-Instruct-2506-GGUF
+  - gguf: https://huggingface.co/unsloth/Mistral-Small-3.2-24B-Instruct-2506-GGUF/tree/main
+
+## Examples
+
+```
+.\bin\Release\sd.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux2-dev-Q4_K_S.gguf --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\Mistral-Small-3.2-24B-Instruct-2506-Q4_K_M.gguf -r .\kontext_input.png -p "change 'flux.cpp' to 'flux2-dev.cpp'" --cfg-scale 1.0 --sampling-method euler -v --diffusion-fa --offload-to-cpu
+```
+
+<img alt="flux2 example" src="../assets/flux2/example.png" />
+
+
+
````
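
Note that the `qwenvl` → `llm` rename surfaces in the CLI as well: both the FLUX.2 example above and the updated Qwen Image example below pass the text-encoder GGUF via `--llm` rather than the old `--qwen2vl` flag.
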

docs/qwen_image.md

Lines changed: 1 addition & 1 deletion

````diff
@@ -14,7 +14,7 @@
 ## Examples
 
 ```
-.\bin\Release\sd.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\qwen-image-Q8_0.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --qwen2vl ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct-Q8_0.gguf -p '一个穿着"QWEN"标志的T恤的中国美女正拿着黑色的马克笔面相镜头微笑。她身后的玻璃板上手写体写着 “一、Qwen-Image的技术路线: 探索视觉生成基础模型的极限,开创理解与生成一体化的未来。二、Qwen-Image的模型特色:1、复杂文字渲染。支持中英渲染、自动布局; 2、精准图像编辑。支持文字编辑、物体增减、风格变换。三、Qwen-Image的未来愿景:赋能专业内容创作、助力生成式AI发展。”' --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu -H 1024 -W 1024 --diffusion-fa --flow-shift 3
+.\bin\Release\sd.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\qwen-image-Q8_0.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --llm ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct-Q8_0.gguf -p '一个穿着"QWEN"标志的T恤的中国美女正拿着黑色的马克笔面相镜头微笑。她身后的玻璃板上手写体写着 “一、Qwen-Image的技术路线: 探索视觉生成基础模型的极限,开创理解与生成一体化的未来。二、Qwen-Image的模型特色:1、复杂文字渲染。支持中英渲染、自动布局; 2、精准图像编辑。支持文字编辑、物体增减、风格变换。三、Qwen-Image的未来愿景:赋能专业内容创作、助力生成式AI发展。”' --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu -H 1024 -W 1024 --diffusion-fa --flow-shift 3
```
 
 <img alt="qwen example" src="../assets/qwen/example.png" />
````
