Skip to content

Commit 1b1ee17

Browse files
committed
Port VQ-diffusion text embeddings (CLIP)
Port the text embeddings (CLIP) for the ITHQ dataset. The `convert_vq_diffusion_to_diffusers.py` script now uses transformers to pull CLIP and save it along with the rest of the model. Note that in VQ-diffusion, the output text embeddings are additionally normalized. The additional normalization will be added to the pipeline as a part of inference in a later commit.
1 parent a0be2a1 commit 1b1ee17

File tree

2 files changed

+44
-3
lines changed

2 files changed

+44
-3
lines changed

scripts/convert_vq_diffusion_to_diffusers.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from diffusers import VQModel
3636
from diffusers.models.vq_diffusion_attention import VQDiffusionTransformer
3737
from diffusers.pipelines import VQDiffusionPipeline
38+
from transformers import CLIPTextModel, CLIPTokenizer
3839
from yaml.loader import FullLoader
3940

4041

@@ -819,9 +820,37 @@ def read_config_file(filename):
819820

820821
# done transformer_model
821822

823+
# text encoder
824+
825+
print("loading CLIP text encoder")
826+
827+
clip_name = "openai/clip-vit-base-patch32"
828+
829+
# The original VQ-Diffusion specifies the pad value by the int used in the
830+
# returned tokens. Each model uses `0` as the pad value. The transformers clip api
831+
# specifies the pad value via the token before it has been tokenized. The `!` pad
832+
# token is the same as padding with the `0` pad value.
833+
pad_token = "!"
834+
835+
tokenizer_model = CLIPTokenizer.from_pretrained(clip_name, pad_token=pad_token, device_map="auto")
836+
837+
assert tokenizer_model.convert_tokens_to_ids(pad_token) == 0
838+
839+
text_encoder_model = CLIPTextModel.from_pretrained(
840+
clip_name,
841+
# `CLIPTextModel` does not support device_map="auto"
842+
# device_map="auto"
843+
)
844+
845+
print("done loading CLIP text encoder")
846+
847+
# done text encoder
848+
822849
print(f"saving VQ diffusion model, path: {args.dump_path}")
823850

824-
pipe = VQDiffusionPipeline(vqvae=vqvae_model, transformer=transformer_model)
851+
pipe = VQDiffusionPipeline(
852+
vqvae=vqvae_model, transformer=transformer_model, tokenizer=tokenizer_model, text_encoder=text_encoder_model
853+
)
825854
pipe.save_pretrained(args.dump_path)
826855

827856
print("done writing VQ diffusion model")

src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from diffusers import VQDiffusionTransformer, VQModel
2+
from transformers import CLIPTextModel, CLIPTokenizer
23

34
from ...pipeline_utils import DiffusionPipeline
45

@@ -14,6 +15,17 @@ class VQDiffusionPipeline(DiffusionPipeline):
1415
vqvae: VQModel
1516
transformer: VQDiffusionTransformer
1617

17-
def __init__(self, vqvae: VQModel, transformer: VQDiffusionTransformer):
18+
def __init__(
19+
self,
20+
vqvae: VQModel,
21+
transformer: VQDiffusionTransformer,
22+
text_encoder: CLIPTextModel,
23+
tokenizer: CLIPTokenizer,
24+
):
1825
super().__init__()
19-
self.register_modules(vqvae=vqvae, transformer=transformer)
26+
self.register_modules(
27+
vqvae=vqvae,
28+
transformer=transformer,
29+
text_encoder=text_encoder,
30+
tokenizer=tokenizer,
31+
)

0 commit comments

Comments
 (0)